[red-knot] Extract red_knot_python_semantic crate (#11926)
commit 2dfbf118d7 (parent ed948eaefb)
23 changed files with 125 additions and 94 deletions
crates/red_knot_python_semantic/Cargo.toml (new file)
@@ -0,0 +1,36 @@
[package]
name = "red_knot_python_semantic"
version = "0.0.0"
publish = false
authors = { workspace = true }
edition = { workspace = true }
rust-version = { workspace = true }
homepage = { workspace = true }
documentation = { workspace = true }
repository = { workspace = true }
license = { workspace = true }

[dependencies]
ruff_db = { workspace = true }
ruff_index = { workspace = true }
ruff_python_ast = { workspace = true }
ruff_python_stdlib = { workspace = true }
ruff_text_size = { workspace = true }

bitflags = { workspace = true }
indexmap = { workspace = true }
salsa = { workspace = true }
smallvec = { workspace = true }
smol_str = { workspace = true }
tracing = { workspace = true }
rustc-hash = { workspace = true }
hashbrown = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
ruff_python_parser = { workspace = true }
tempfile = { workspace = true }

[lints]
workspace = true
crates/red_knot_python_semantic/src/ast_node_ref.rs (new file)
@@ -0,0 +1,162 @@
use std::hash::Hash;
use std::ops::Deref;

use ruff_db::parsed::ParsedModule;

/// Ref-counted owned reference to an AST node.
///
/// The type holds an owned reference to the node's ref-counted [`ParsedModule`].
/// Holding on to the node's [`ParsedModule`] guarantees that the reference to the
/// node must still be valid.
///
/// Holding on to any [`AstNodeRef`] prevents the [`ParsedModule`] from being released.
///
/// ## Equality
/// Two `AstNodeRef` are considered equal if their wrapped nodes are equal.
#[derive(Clone)]
pub struct AstNodeRef<T> {
    /// Owned reference to the node's [`ParsedModule`].
    ///
    /// The node's reference is guaranteed to remain valid as long as its enclosing
    /// [`ParsedModule`] is alive.
    _parsed: ParsedModule,

    /// Pointer to the referenced node.
    node: std::ptr::NonNull<T>,
}

#[allow(unsafe_code)]
impl<T> AstNodeRef<T> {
    /// Creates a new `AstNodeRef` that references `node`. The `parsed` is the [`ParsedModule`] to which
    /// the `AstNodeRef` belongs.
    ///
    /// ## Safety
    /// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the [`ParsedModule`] to
    /// which `node` belongs. It's the caller's responsibility to ensure that the invariant `node belongs to parsed` is upheld.
    pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
        Self {
            _parsed: parsed,
            node: std::ptr::NonNull::from(node),
        }
    }

    /// Returns a reference to the wrapped node.
    pub fn node(&self) -> &T {
        // SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still alive
        // and not moved.
        unsafe { self.node.as_ref() }
    }
}

impl<T> Deref for AstNodeRef<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.node()
    }
}

impl<T> std::fmt::Debug for AstNodeRef<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("AstNodeRef").field(&self.node()).finish()
    }
}

impl<T> PartialEq for AstNodeRef<T>
where
    T: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        self.node().eq(other.node())
    }
}

impl<T> Eq for AstNodeRef<T> where T: Eq {}

impl<T> Hash for AstNodeRef<T>
where
    T: Hash,
{
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.node().hash(state);
    }
}

#[allow(unsafe_code)]
unsafe impl<T> Send for AstNodeRef<T> where T: Send {}
#[allow(unsafe_code)]
unsafe impl<T> Sync for AstNodeRef<T> where T: Sync {}

#[cfg(test)]
mod tests {
    use crate::ast_node_ref::AstNodeRef;
    use ruff_db::parsed::ParsedModule;
    use ruff_python_ast::PySourceType;
    use ruff_python_parser::parse_unchecked_source;

    #[test]
    #[allow(unsafe_code)]
    fn equality() {
        let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
        let parsed = ParsedModule::new(parsed_raw.clone());

        let stmt = &parsed.syntax().body[0];

        let node1 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
        let node2 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };

        assert_eq!(node1, node2);

        // Compare from different trees
        let cloned = ParsedModule::new(parsed_raw);
        let stmt_cloned = &cloned.syntax().body[0];
        let cloned_node = unsafe { AstNodeRef::new(cloned.clone(), stmt_cloned) };

        assert_eq!(node1, cloned_node);

        let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
        let other = ParsedModule::new(other_raw);

        let other_stmt = &other.syntax().body[0];
        let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };

        assert_ne!(node1, other_node);
    }

    #[allow(unsafe_code)]
    #[test]
    fn inequality() {
        let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
        let parsed = ParsedModule::new(parsed_raw.clone());

        let stmt = &parsed.syntax().body[0];
        let node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };

        let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
        let other = ParsedModule::new(other_raw);

        let other_stmt = &other.syntax().body[0];
        let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };

        assert_ne!(node, other_node);
    }

    #[test]
    #[allow(unsafe_code)]
    fn debug() {
        let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
        let parsed = ParsedModule::new(parsed_raw.clone());

        let stmt = &parsed.syntax().body[0];

        let stmt_node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };

        let debug = format!("{stmt_node:?}");

        assert_eq!(debug, format!("AstNodeRef({stmt:?})"));
    }
}
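The safety contract on `AstNodeRef::new` is easiest to see with a small usage sketch. The helper below is hypothetical (it is not part of this commit) and assumes it lives inside this crate, so it can reach the `pub(super)` constructor; the point is that the node being wrapped is borrowed from the very same `ParsedModule` that is handed over for ownership, which is exactly the invariant the caller must uphold.

use ruff_db::parsed::ParsedModule;
use ruff_python_ast::{Stmt, StmtFunctionDef};

use crate::ast_node_ref::AstNodeRef;

/// Hypothetical helper: wrap the first `def` of a module in an `AstNodeRef`.
#[allow(unsafe_code)]
fn first_function_ref(parsed: &ParsedModule) -> Option<AstNodeRef<StmtFunctionDef>> {
    // Borrow the node from the same `ParsedModule` that we clone below.
    let function = parsed.syntax().body.iter().find_map(|stmt| match stmt {
        Stmt::FunctionDef(function) => Some(function),
        _ => None,
    })?;

    // SAFETY: `function` is borrowed from `parsed`, and the clone stored inside the
    // `AstNodeRef` keeps that same AST alive (and un-moved) for as long as the reference exists.
    Some(unsafe { AstNodeRef::new(parsed.clone(), function) })
}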
crates/red_knot_python_semantic/src/db.rs (new file)
@@ -0,0 +1,265 @@
use salsa::DbWithJar;

use ruff_db::{Db as SourceDb, Upcast};

use crate::module::resolver::{
    file_to_module, internal::ModuleNameIngredient, internal::ModuleResolverSearchPaths,
    resolve_module_query,
};
use crate::semantic_index::symbol::{public_symbols_map, scopes_map, PublicSymbolId, ScopeId};
use crate::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::types::{infer_types, public_symbol_ty};

#[salsa::jar(db=Db)]
pub struct Jar(
    ModuleNameIngredient,
    ModuleResolverSearchPaths,
    ScopeId,
    PublicSymbolId,
    symbol_table,
    resolve_module_query,
    file_to_module,
    scopes_map,
    root_scope,
    semantic_index,
    infer_types,
    public_symbol_ty,
    public_symbols_map,
);

/// Database giving access to semantic information about a Python program.
pub trait Db: SourceDb + DbWithJar<Jar> + Upcast<dyn SourceDb> {}

#[cfg(test)]
pub(crate) mod tests {
    use std::fmt::Formatter;
    use std::marker::PhantomData;
    use std::sync::Arc;

    use salsa::ingredient::Ingredient;
    use salsa::storage::HasIngredientsFor;
    use salsa::{AsId, DebugWithDb};

    use ruff_db::file_system::{FileSystem, MemoryFileSystem, OsFileSystem};
    use ruff_db::vfs::Vfs;
    use ruff_db::{Db as SourceDb, Jar as SourceJar, Upcast};

    use super::{Db, Jar};

    #[salsa::db(Jar, SourceJar)]
    pub(crate) struct TestDb {
        storage: salsa::Storage<Self>,
        vfs: Vfs,
        file_system: TestFileSystem,
        events: std::sync::Arc<std::sync::Mutex<Vec<salsa::Event>>>,
    }

    impl TestDb {
        pub(crate) fn new() -> Self {
            Self {
                storage: salsa::Storage::default(),
                file_system: TestFileSystem::Memory(MemoryFileSystem::default()),
                events: std::sync::Arc::default(),
                vfs: Vfs::with_stubbed_vendored(),
            }
        }

        /// Returns the memory file system.
        ///
        /// ## Panics
        /// If this test db isn't using a memory file system.
        pub(crate) fn memory_file_system(&self) -> &MemoryFileSystem {
            if let TestFileSystem::Memory(fs) = &self.file_system {
                fs
            } else {
                panic!("The test db is not using a memory file system");
            }
        }

        /// Uses the real file system instead of the memory file system.
        ///
        /// This is useful for testing advanced file system features like permissions, symlinks, etc.
        ///
        /// Note that any files written to the memory file system won't be copied over.
        #[allow(unused)]
        pub(crate) fn with_os_file_system(&mut self) {
            self.file_system = TestFileSystem::Os(OsFileSystem);
        }

        #[allow(unused)]
        pub(crate) fn vfs_mut(&mut self) -> &mut Vfs {
            &mut self.vfs
        }

        /// Takes the salsa events.
        ///
        /// ## Panics
        /// If there are any pending salsa snapshots.
        pub(crate) fn take_salsa_events(&mut self) -> Vec<salsa::Event> {
            let inner = Arc::get_mut(&mut self.events).expect("no pending salsa snapshots");

            let events = inner.get_mut().unwrap();
            std::mem::take(&mut *events)
        }

        /// Clears the salsa events.
        ///
        /// ## Panics
        /// If there are any pending salsa snapshots.
        pub(crate) fn clear_salsa_events(&mut self) {
            self.take_salsa_events();
        }
    }

    impl SourceDb for TestDb {
        fn file_system(&self) -> &dyn FileSystem {
            match &self.file_system {
                TestFileSystem::Memory(fs) => fs,
                TestFileSystem::Os(fs) => fs,
            }
        }

        fn vfs(&self) -> &Vfs {
            &self.vfs
        }
    }

    impl Upcast<dyn SourceDb> for TestDb {
        fn upcast(&self) -> &(dyn SourceDb + 'static) {
            self
        }
    }

    impl Db for TestDb {}

    impl salsa::Database for TestDb {
        fn salsa_event(&self, event: salsa::Event) {
            tracing::trace!("event: {:?}", event.debug(self));
            let mut events = self.events.lock().unwrap();
            events.push(event);
        }
    }

    impl salsa::ParallelDatabase for TestDb {
        fn snapshot(&self) -> salsa::Snapshot<Self> {
            salsa::Snapshot::new(Self {
                storage: self.storage.snapshot(),
                vfs: self.vfs.snapshot(),
                file_system: match &self.file_system {
                    TestFileSystem::Memory(memory) => TestFileSystem::Memory(memory.snapshot()),
                    TestFileSystem::Os(fs) => TestFileSystem::Os(fs.snapshot()),
                },
                events: self.events.clone(),
            })
        }
    }

    enum TestFileSystem {
        Memory(MemoryFileSystem),
        #[allow(unused)]
        Os(OsFileSystem),
    }

    pub(crate) fn assert_will_run_function_query<C, Db, Jar>(
        db: &Db,
        to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
        key: C::Key,
        events: &[salsa::Event],
    ) where
        C: salsa::function::Configuration<Jar = Jar>
            + salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
        Jar: HasIngredientsFor<C>,
        Db: salsa::DbWithJar<Jar>,
        C::Key: AsId,
    {
        will_run_function_query(db, to_function, key, events, true);
    }

    pub(crate) fn assert_will_not_run_function_query<C, Db, Jar>(
        db: &Db,
        to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
        key: C::Key,
        events: &[salsa::Event],
    ) where
        C: salsa::function::Configuration<Jar = Jar>
            + salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
        Jar: HasIngredientsFor<C>,
        Db: salsa::DbWithJar<Jar>,
        C::Key: AsId,
    {
        will_run_function_query(db, to_function, key, events, false);
    }

    fn will_run_function_query<C, Db, Jar>(
        db: &Db,
        to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
        key: C::Key,
        events: &[salsa::Event],
        should_run: bool,
    ) where
        C: salsa::function::Configuration<Jar = Jar>
            + salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
        Jar: HasIngredientsFor<C>,
        Db: salsa::DbWithJar<Jar>,
        C::Key: AsId,
    {
        let (jar, _) =
            <_ as salsa::storage::HasJar<<C as salsa::storage::IngredientsFor>::Jar>>::jar(db);
        let ingredient = jar.ingredient();

        let function_ingredient = to_function(ingredient);

        let ingredient_index =
            <salsa::function::FunctionIngredient<C> as Ingredient<Db>>::ingredient_index(
                function_ingredient,
            );

        let did_run = events.iter().any(|event| {
            if let salsa::EventKind::WillExecute { database_key } = event.kind {
                database_key.ingredient_index() == ingredient_index
                    && database_key.key_index() == key.as_id()
            } else {
                false
            }
        });

        if should_run && !did_run {
            panic!(
                "Expected query {:?} to run but it didn't",
                DebugIdx {
                    db: PhantomData::<Db>,
                    value_id: key.as_id(),
                    ingredient: function_ingredient,
                }
            );
        } else if !should_run && did_run {
            panic!(
                "Expected query {:?} not to run but it did",
                DebugIdx {
                    db: PhantomData::<Db>,
                    value_id: key.as_id(),
                    ingredient: function_ingredient,
                }
            );
        }
    }

    struct DebugIdx<'a, I, Db>
    where
        I: Ingredient<Db>,
    {
        value_id: salsa::Id,
        ingredient: &'a I,
        db: PhantomData<Db>,
    }

    impl<'a, I, Db> std::fmt::Debug for DebugIdx<'a, I, Db>
    where
        I: Ingredient<Db>,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            self.ingredient.fmt_index(Some(self.value_id), f)
        }
    }
}
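For readers outside the crate, a real (non-test) database has the same shape as `TestDb` above: it owns the salsa storage plus the `ruff_db` state and wires up `SourceDb`, `Upcast<dyn SourceDb>`, and the marker `Db` trait. The sketch below is hypothetical and not part of this commit; it mirrors `TestDb` but uses the OS file system, and it elides `ParallelDatabase` and event recording for brevity.

use red_knot_python_semantic::{Db, Jar};
use ruff_db::file_system::{FileSystem, OsFileSystem};
use ruff_db::vfs::Vfs;
use ruff_db::{Db as SourceDb, Jar as SourceJar, Upcast};

// Hypothetical consumer database: same jars as `TestDb`, backed by the real file system.
#[salsa::db(Jar, SourceJar)]
struct Program {
    storage: salsa::Storage<Self>,
    vfs: Vfs,
    file_system: OsFileSystem,
}

impl SourceDb for Program {
    fn file_system(&self) -> &dyn FileSystem {
        &self.file_system
    }

    fn vfs(&self) -> &Vfs {
        &self.vfs
    }
}

impl Upcast<dyn SourceDb> for Program {
    fn upcast(&self) -> &(dyn SourceDb + 'static) {
        self
    }
}

impl salsa::Database for Program {
    // No event logging in this sketch.
    fn salsa_event(&self, _event: salsa::Event) {}
}

impl Db for Program {}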
crates/red_knot_python_semantic/src/lib.rs (new file)
@@ -0,0 +1,13 @@
pub mod ast_node_ref;
mod db;
pub mod module;
pub mod name;
mod node_key;
pub mod semantic_index;
pub mod types;

type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;

pub use db::{Db, Jar};
use rustc_hash::FxHasher;
use std::hash::BuildHasherDefault;
crates/red_knot_python_semantic/src/mod.rs (new file)
@@ -0,0 +1,10 @@
use std::hash::BuildHasherDefault;

use rustc_hash::FxHasher;

pub mod ast_node_ref;
mod node_key;
pub mod semantic_index;
pub mod types;

pub(crate) type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;
crates/red_knot_python_semantic/src/module.rs (new file)
@@ -0,0 +1,332 @@
use std::fmt::Formatter;
use std::ops::Deref;
use std::sync::Arc;

use ruff_db::file_system::FileSystemPath;
use ruff_db::vfs::{VfsFile, VfsPath};
use ruff_python_stdlib::identifiers::is_identifier;

use crate::Db;

pub mod resolver;

/// A module name, e.g. `foo.bar`.
///
/// Always normalized to the absolute form (never a relative module name, i.e., never `.foo`).
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ModuleName(smol_str::SmolStr);

impl ModuleName {
    /// Creates a new module name for `name`. Returns `Some` if `name` is a valid, absolute
    /// module name and `None` otherwise.
    ///
    /// The module name is invalid if:
    ///
    /// * The name is empty
    /// * The name is relative
    /// * The name ends with a `.`
    /// * The name contains a sequence of multiple dots
    /// * A component of a name (the part between two dots) isn't a valid Python identifier.
    #[inline]
    pub fn new(name: &str) -> Option<Self> {
        Self::new_from_smol(smol_str::SmolStr::new(name))
    }

    /// Creates a new module name for `name` where `name` is a static string.
    /// Returns `Some` if `name` is a valid, absolute module name and `None` otherwise.
    ///
    /// The module name is invalid if:
    ///
    /// * The name is empty
    /// * The name is relative
    /// * The name ends with a `.`
    /// * The name contains a sequence of multiple dots
    /// * A component of a name (the part between two dots) isn't a valid Python identifier.
    ///
    /// ## Examples
    ///
    /// ```
    /// use red_knot_python_semantic::module::ModuleName;
    ///
    /// assert_eq!(ModuleName::new_static("foo.bar").as_deref(), Some("foo.bar"));
    /// assert_eq!(ModuleName::new_static(""), None);
    /// assert_eq!(ModuleName::new_static("..foo"), None);
    /// assert_eq!(ModuleName::new_static(".foo"), None);
    /// assert_eq!(ModuleName::new_static("foo."), None);
    /// assert_eq!(ModuleName::new_static("foo..bar"), None);
    /// assert_eq!(ModuleName::new_static("2000"), None);
    /// ```
    #[inline]
    pub fn new_static(name: &'static str) -> Option<Self> {
        Self::new_from_smol(smol_str::SmolStr::new_static(name))
    }

    fn new_from_smol(name: smol_str::SmolStr) -> Option<Self> {
        if name.is_empty() {
            return None;
        }

        if name.split('.').all(is_identifier) {
            Some(Self(name))
        } else {
            None
        }
    }

    /// An iterator over the components of the module name.
    ///
    /// # Examples
    ///
    /// ```
    /// use red_knot_python_semantic::module::ModuleName;
    ///
    /// assert_eq!(ModuleName::new_static("foo.bar.baz").unwrap().components().collect::<Vec<_>>(), vec!["foo", "bar", "baz"]);
    /// ```
    pub fn components(&self) -> impl DoubleEndedIterator<Item = &str> {
        self.0.split('.')
    }

    /// The name of this module's immediate parent, if it has a parent.
    ///
    /// # Examples
    ///
    /// ```
    /// use red_knot_python_semantic::module::ModuleName;
    ///
    /// assert_eq!(ModuleName::new_static("foo.bar").unwrap().parent(), Some(ModuleName::new_static("foo").unwrap()));
    /// assert_eq!(ModuleName::new_static("foo.bar.baz").unwrap().parent(), Some(ModuleName::new_static("foo.bar").unwrap()));
    /// assert_eq!(ModuleName::new_static("root").unwrap().parent(), None);
    /// ```
    pub fn parent(&self) -> Option<ModuleName> {
        let (parent, _) = self.0.rsplit_once('.')?;

        Some(Self(smol_str::SmolStr::new(parent)))
    }

    /// Returns `true` if the name starts with `other`.
    ///
    /// This is equivalent to checking if `self` is a sub-module of `other`.
    ///
    /// # Examples
    ///
    /// ```
    /// use red_knot_python_semantic::module::ModuleName;
    ///
    /// assert!(ModuleName::new_static("foo.bar").unwrap().starts_with(&ModuleName::new_static("foo").unwrap()));
    ///
    /// assert!(!ModuleName::new_static("foo.bar").unwrap().starts_with(&ModuleName::new_static("bar").unwrap()));
    /// assert!(!ModuleName::new_static("foo_bar").unwrap().starts_with(&ModuleName::new_static("foo").unwrap()));
    /// ```
    pub fn starts_with(&self, other: &ModuleName) -> bool {
        let mut self_components = self.components();
        let other_components = other.components();

        for other_component in other_components {
            if self_components.next() != Some(other_component) {
                return false;
            }
        }

        true
    }

    #[inline]
    pub fn as_str(&self) -> &str {
        &self.0
    }

    fn from_relative_path(path: &FileSystemPath) -> Option<Self> {
        let path = if path.ends_with("__init__.py") || path.ends_with("__init__.pyi") {
            path.parent()?
        } else {
            path
        };

        let name = if let Some(parent) = path.parent() {
            let mut name = String::with_capacity(path.as_str().len());

            for component in parent.components() {
                name.push_str(component.as_os_str().to_str()?);
                name.push('.');
            }

            // SAFETY: Unwrap is safe here or `parent` would have returned `None`.
            name.push_str(path.file_stem().unwrap());

            smol_str::SmolStr::from(name)
        } else {
            smol_str::SmolStr::new(path.file_stem()?)
        };

        Some(Self(name))
    }
}

impl Deref for ModuleName {
    type Target = str;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.as_str()
    }
}

impl PartialEq<str> for ModuleName {
    fn eq(&self, other: &str) -> bool {
        self.as_str() == other
    }
}

impl PartialEq<ModuleName> for str {
    fn eq(&self, other: &ModuleName) -> bool {
        self == other.as_str()
    }
}

impl std::fmt::Display for ModuleName {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

/// Representation of a Python module.
#[derive(Clone, PartialEq, Eq)]
pub struct Module {
    inner: Arc<ModuleInner>,
}

impl Module {
    /// The absolute name of the module (e.g. `foo.bar`)
    pub fn name(&self) -> &ModuleName {
        &self.inner.name
    }

    /// The file containing the source code that defines this module
    pub fn file(&self) -> VfsFile {
        self.inner.file
    }

    /// The search path from which the module was resolved.
    pub fn search_path(&self) -> &ModuleSearchPath {
        &self.inner.search_path
    }

    /// Determine whether this module is a single-file module or a package
    pub fn kind(&self) -> ModuleKind {
        self.inner.kind
    }
}

impl std::fmt::Debug for Module {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Module")
            .field("name", &self.name())
            .field("kind", &self.kind())
            .field("file", &self.file())
            .field("search_path", &self.search_path())
            .finish()
    }
}

impl salsa::DebugWithDb<dyn Db> for Module {
    fn fmt(&self, f: &mut Formatter<'_>, db: &dyn Db) -> std::fmt::Result {
        f.debug_struct("Module")
            .field("name", &self.name())
            .field("kind", &self.kind())
            .field("file", &self.file().debug(db.upcast()))
            .field("search_path", &self.search_path())
            .finish()
    }
}

#[derive(PartialEq, Eq)]
struct ModuleInner {
    name: ModuleName,
    kind: ModuleKind,
    search_path: ModuleSearchPath,
    file: VfsFile,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum ModuleKind {
    /// A single-file module (e.g. `foo.py` or `foo.pyi`)
    Module,

    /// A Python package (`foo/__init__.py` or `foo/__init__.pyi`)
    Package,
}

/// A search path in which to search modules.
/// Corresponds to a path in [`sys.path`](https://docs.python.org/3/library/sys_path_init.html) at runtime.
///
/// Cloning a search path is cheap because it's an `Arc`.
#[derive(Clone, PartialEq, Eq)]
pub struct ModuleSearchPath {
    inner: Arc<ModuleSearchPathInner>,
}

impl ModuleSearchPath {
    pub fn new<P>(path: P, kind: ModuleSearchPathKind) -> Self
    where
        P: Into<VfsPath>,
    {
        Self {
            inner: Arc::new(ModuleSearchPathInner {
                path: path.into(),
                kind,
            }),
        }
    }

    /// Determine whether this is a first-party, third-party or standard-library search path
    pub fn kind(&self) -> ModuleSearchPathKind {
        self.inner.kind
    }

    /// Return the location of the search path on the file system
    pub fn path(&self) -> &VfsPath {
        &self.inner.path
    }
}

impl std::fmt::Debug for ModuleSearchPath {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModuleSearchPath")
            .field("path", &self.inner.path)
            .field("kind", &self.kind())
            .finish()
    }
}

#[derive(Eq, PartialEq)]
struct ModuleSearchPathInner {
    path: VfsPath,
    kind: ModuleSearchPathKind,
}

/// Enumeration of the different kinds of search paths type checkers are expected to support.
///
/// N.B. Although we don't implement `Ord` for this enum, they are ordered in terms of the
/// priority that we want to give these modules when resolving them.
/// This is roughly [the order given in the typing spec], but typeshed's stubs
/// for the standard library are moved higher up to match Python's semantics at runtime.
///
/// [the order given in the typing spec]: https://typing.readthedocs.io/en/latest/spec/distributing.html#import-resolution-ordering
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum ModuleSearchPathKind {
    /// "Extra" paths provided by the user in a config file, env var or CLI flag.
    /// E.g. mypy's `MYPYPATH` env var, or pyright's `stubPath` configuration setting
    Extra,

    /// Files in the project we're directly being invoked on
    FirstParty,

    /// The `stdlib` directory of typeshed (either vendored or custom)
    StandardLibrary,

    /// Stubs or runtime modules installed in site-packages
    SitePackagesThirdParty,

    /// Vendored third-party stubs from typeshed
    VendoredThirdParty,
}
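To make the `ModuleName` semantics documented above concrete, here is a small standalone sketch (it simply restates the doctests and assumes the crate is available as a dependency): `starts_with` compares dotted components rather than raw string prefixes, and `parent` strips one component at a time.

use red_knot_python_semantic::module::ModuleName;

fn main() {
    let foo = ModuleName::new_static("foo").unwrap();
    let foo_bar = ModuleName::new_static("foo.bar").unwrap();
    let foo_baz = ModuleName::new_static("foo_baz").unwrap();

    // `foo.bar` is a submodule of `foo` ...
    assert!(foo_bar.starts_with(&foo));
    // ... but `foo_baz` merely shares a string prefix with `foo`.
    assert!(!foo_baz.starts_with(&foo));

    // `parent` walks up one dotted component at a time until the root module.
    assert_eq!(foo_bar.parent(), Some(foo));
    assert_eq!(ModuleName::new_static("foo").unwrap().parent(), None);
}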
944
crates/red_knot_python_semantic/src/module/resolver.rs
Normal file
944
crates/red_knot_python_semantic/src/module/resolver.rs
Normal file
|
@ -0,0 +1,944 @@
|
|||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use ruff_db::file_system::{FileSystem, FileSystemPath, FileSystemPathBuf};
|
||||
use ruff_db::vfs::{system_path_to_file, vfs_path_to_file, VfsFile, VfsPath};
|
||||
|
||||
use crate::module::resolver::internal::ModuleResolverSearchPaths;
|
||||
use crate::module::{
|
||||
Module, ModuleInner, ModuleKind, ModuleName, ModuleSearchPath, ModuleSearchPathKind,
|
||||
};
|
||||
use crate::Db;
|
||||
|
||||
const TYPESHED_STDLIB_DIRECTORY: &str = "stdlib";
|
||||
|
||||
/// Configures the module search paths for the module resolver.
|
||||
///
|
||||
/// Must be called before calling any other module resolution functions.
|
||||
pub fn set_module_resolution_settings(db: &mut dyn Db, config: ModuleResolutionSettings) {
|
||||
// There's no concurrency issue here because we hold a `&mut dyn Db` reference. No other
|
||||
// thread can mutate the `Db` while we're in this call, so using `try_get` to test if
|
||||
// the settings have already been set is safe.
|
||||
if let Some(existing) = ModuleResolverSearchPaths::try_get(db) {
|
||||
existing
|
||||
.set_search_paths(db)
|
||||
.to(config.into_ordered_search_paths());
|
||||
} else {
|
||||
ModuleResolverSearchPaths::new(db, config.into_ordered_search_paths());
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolves a module name to a module.
|
||||
#[tracing::instrument(level = "debug", skip(db))]
|
||||
pub fn resolve_module(db: &dyn Db, module_name: ModuleName) -> Option<Module> {
|
||||
let interned_name = internal::ModuleNameIngredient::new(db, module_name);
|
||||
|
||||
resolve_module_query(db, interned_name)
|
||||
}
|
||||
|
||||
/// Salsa query that resolves an interned [`ModuleNameIngredient`] to a module.
|
||||
///
|
||||
/// This query should not be called directly. Instead, use [`resolve_module`]. It only exists
|
||||
/// because Salsa requires the module name to be an ingredient.
|
||||
#[salsa::tracked]
|
||||
pub(crate) fn resolve_module_query(
|
||||
db: &dyn Db,
|
||||
module_name: internal::ModuleNameIngredient,
|
||||
) -> Option<Module> {
|
||||
let name = module_name.name(db);
|
||||
|
||||
let (search_path, module_file, kind) = resolve_name(db, name)?;
|
||||
|
||||
let module = Module {
|
||||
inner: Arc::new(ModuleInner {
|
||||
name: name.clone(),
|
||||
kind,
|
||||
search_path,
|
||||
file: module_file,
|
||||
}),
|
||||
};
|
||||
|
||||
Some(module)
|
||||
}
|
||||
|
||||
/// Resolves the module for the given path.
|
||||
///
|
||||
/// Returns `None` if the path is not a module locatable via `sys.path`.
|
||||
#[tracing::instrument(level = "debug", skip(db))]
|
||||
pub fn path_to_module(db: &dyn Db, path: &VfsPath) -> Option<Module> {
|
||||
// It's not entirely clear on first sight why this method calls `file_to_module` instead of
|
||||
// it being the other way round, considering that the first thing that `file_to_module` does
|
||||
// is to retrieve the file's path.
|
||||
//
|
||||
// The reason is that `file_to_module` is a tracked Salsa query and salsa queries require that
|
||||
// all arguments are Salsa ingredients (something stored in Salsa). `Path`s aren't salsa ingredients but
|
||||
// `VfsFile` is. So what we do here is to retrieve the `path`'s `VfsFile` so that we can make
|
||||
// use of Salsa's caching and invalidation.
|
||||
let file = vfs_path_to_file(db.upcast(), path)?;
|
||||
file_to_module(db, file)
|
||||
}
|
||||
|
||||
/// Resolves the module for the file with the given id.
|
||||
///
|
||||
/// Returns `None` if the file is not a module locatable via `sys.path`.
|
||||
#[salsa::tracked]
|
||||
#[tracing::instrument(level = "debug", skip(db))]
|
||||
pub fn file_to_module(db: &dyn Db, file: VfsFile) -> Option<Module> {
|
||||
let path = file.path(db.upcast());
|
||||
|
||||
let search_paths = module_search_paths(db);
|
||||
|
||||
let relative_path = search_paths
|
||||
.iter()
|
||||
.find_map(|root| match (root.path(), path) {
|
||||
(VfsPath::FileSystem(root_path), VfsPath::FileSystem(path)) => {
|
||||
let relative_path = path.strip_prefix(root_path).ok()?;
|
||||
Some(relative_path)
|
||||
}
|
||||
(VfsPath::Vendored(_), VfsPath::Vendored(_)) => {
|
||||
todo!("Add support for vendored modules")
|
||||
}
|
||||
(VfsPath::Vendored(_), VfsPath::FileSystem(_))
|
||||
| (VfsPath::FileSystem(_), VfsPath::Vendored(_)) => None,
|
||||
})?;
|
||||
|
||||
let module_name = ModuleName::from_relative_path(relative_path)?;
|
||||
|
||||
// Resolve the module name to see if Python would resolve the name to the same path.
|
||||
// If it doesn't, then that means that multiple modules have the same name in different
|
||||
// root paths, but that the module corresponding to `path` is in a lower priority search path,
|
||||
// in which case we ignore it.
|
||||
let module = resolve_module(db, module_name)?;
|
||||
|
||||
if file == module.file() {
|
||||
Some(module)
|
||||
} else {
|
||||
// This path is for a module with the same name but with a different precedence. For example:
|
||||
// ```
|
||||
// src/foo.py
|
||||
// src/foo/__init__.py
|
||||
// ```
|
||||
// The module name of `src/foo.py` is `foo`, but the module loaded by Python is `src/foo/__init__.py`.
|
||||
// That means we need to ignore `src/foo.py` even though it resolves to the same module name.
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Configures the [`ModuleSearchPath`]s that are used to resolve modules.
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
pub struct ModuleResolutionSettings {
|
||||
/// List of user-provided paths that should take first priority in the module resolution.
|
||||
/// Examples in other type checkers are mypy's MYPYPATH environment variable,
|
||||
/// or pyright's stubPath configuration setting.
|
||||
pub extra_paths: Vec<FileSystemPathBuf>,
|
||||
|
||||
/// The root of the workspace, used for finding first-party modules.
|
||||
pub workspace_root: FileSystemPathBuf,
|
||||
|
||||
/// The path to the user's `site-packages` directory, where third-party packages from ``PyPI`` are installed.
|
||||
pub site_packages: Option<FileSystemPathBuf>,
|
||||
|
||||
/// Optional path to standard-library typeshed stubs.
|
||||
/// Currently this has to be a directory that exists on disk.
|
||||
///
|
||||
/// (TODO: fall back to vendored stubs if no custom directory is provided.)
|
||||
pub custom_typeshed: Option<FileSystemPathBuf>,
|
||||
}
|
||||
|
||||
impl ModuleResolutionSettings {
|
||||
/// Implementation of PEP 561's module resolution order
|
||||
/// (with some small, deliberate, differences)
|
||||
fn into_ordered_search_paths(self) -> OrderedSearchPaths {
|
||||
let ModuleResolutionSettings {
|
||||
extra_paths,
|
||||
workspace_root,
|
||||
site_packages,
|
||||
custom_typeshed,
|
||||
} = self;
|
||||
|
||||
let mut paths: Vec<_> = extra_paths
|
||||
.into_iter()
|
||||
.map(|path| ModuleSearchPath::new(path, ModuleSearchPathKind::Extra))
|
||||
.collect();
|
||||
|
||||
paths.push(ModuleSearchPath::new(
|
||||
workspace_root,
|
||||
ModuleSearchPathKind::FirstParty,
|
||||
));
|
||||
|
||||
// TODO fallback to vendored typeshed stubs if no custom typeshed directory is provided by the user
|
||||
if let Some(custom_typeshed) = custom_typeshed {
|
||||
paths.push(ModuleSearchPath::new(
|
||||
custom_typeshed.join(TYPESHED_STDLIB_DIRECTORY),
|
||||
ModuleSearchPathKind::StandardLibrary,
|
||||
));
|
||||
}
|
||||
|
||||
// TODO vendor typeshed's third-party stubs as well as the stdlib and fallback to them as a final step
|
||||
if let Some(site_packages) = site_packages {
|
||||
paths.push(ModuleSearchPath::new(
|
||||
site_packages,
|
||||
ModuleSearchPathKind::SitePackagesThirdParty,
|
||||
));
|
||||
}
|
||||
|
||||
OrderedSearchPaths(paths)
|
||||
}
|
||||
}
|
||||
|
||||
/// A resolved module resolution order, implementing PEP 561
|
||||
/// (with some small, deliberate differences)
|
||||
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
||||
pub(crate) struct OrderedSearchPaths(Vec<ModuleSearchPath>);
|
||||
|
||||
impl Deref for OrderedSearchPaths {
|
||||
type Target = [ModuleSearchPath];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
// The singleton methods generated by salsa are all `pub` instead of `pub(crate)` which triggers
|
||||
// `unreachable_pub`. Work around this by creating a module and allow `unreachable_pub` for it.
|
||||
// Salsa also generates uses to `_db` variables for `interned` which triggers `clippy::used_underscore_binding`. Suppress that too
|
||||
// TODO(micha): Contribute a fix for this upstream where the singleton methods have the same visibility as the struct.
|
||||
#[allow(unreachable_pub, clippy::used_underscore_binding)]
|
||||
pub(crate) mod internal {
|
||||
use crate::module::resolver::OrderedSearchPaths;
|
||||
use crate::module::ModuleName;
|
||||
|
||||
#[salsa::input(singleton)]
|
||||
pub(crate) struct ModuleResolverSearchPaths {
|
||||
#[return_ref]
|
||||
pub(super) search_paths: OrderedSearchPaths,
|
||||
}
|
||||
|
||||
/// A thin wrapper around `ModuleName` to make it a Salsa ingredient.
|
||||
///
|
||||
/// This is needed because Salsa requires that all query arguments are salsa ingredients.
|
||||
#[salsa::interned]
|
||||
pub(crate) struct ModuleNameIngredient {
|
||||
#[return_ref]
|
||||
pub(super) name: ModuleName,
|
||||
}
|
||||
}
|
||||
|
||||
fn module_search_paths(db: &dyn Db) -> &[ModuleSearchPath] {
|
||||
ModuleResolverSearchPaths::get(db).search_paths(db)
|
||||
}
|
||||
|
||||
/// Given a module name and a list of search paths in which to lookup modules,
|
||||
/// attempt to resolve the module name
|
||||
fn resolve_name(db: &dyn Db, name: &ModuleName) -> Option<(ModuleSearchPath, VfsFile, ModuleKind)> {
|
||||
let search_paths = module_search_paths(db);
|
||||
|
||||
for search_path in search_paths {
|
||||
let mut components = name.components();
|
||||
let module_name = components.next_back()?;
|
||||
|
||||
let VfsPath::FileSystem(fs_search_path) = search_path.path() else {
|
||||
todo!("Vendored search paths are not yet supported");
|
||||
};
|
||||
|
||||
match resolve_package(db.file_system(), fs_search_path, components) {
|
||||
Ok(resolved_package) => {
|
||||
let mut package_path = resolved_package.path;
|
||||
|
||||
package_path.push(module_name);
|
||||
|
||||
// Must be a `__init__.pyi` or `__init__.py` or it isn't a package.
|
||||
let kind = if db.file_system().is_directory(&package_path) {
|
||||
package_path.push("__init__");
|
||||
ModuleKind::Package
|
||||
} else {
|
||||
ModuleKind::Module
|
||||
};
|
||||
|
||||
// TODO Implement full https://peps.python.org/pep-0561/#type-checker-module-resolution-order resolution
|
||||
let stub = package_path.with_extension("pyi");
|
||||
|
||||
if let Some(stub) = system_path_to_file(db.upcast(), &stub) {
|
||||
return Some((search_path.clone(), stub, kind));
|
||||
}
|
||||
|
||||
let module = package_path.with_extension("py");
|
||||
|
||||
if let Some(module) = system_path_to_file(db.upcast(), &module) {
|
||||
return Some((search_path.clone(), module, kind));
|
||||
}
|
||||
|
||||
// For regular packages, don't search the next search path. All files of that
|
||||
// package must be in the same location
|
||||
if resolved_package.kind.is_regular_package() {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
Err(parent_kind) => {
|
||||
if parent_kind.is_regular_package() {
|
||||
// For regular packages, don't search the next search path.
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn resolve_package<'a, I>(
|
||||
fs: &dyn FileSystem,
|
||||
module_search_path: &FileSystemPath,
|
||||
components: I,
|
||||
) -> Result<ResolvedPackage, PackageKind>
|
||||
where
|
||||
I: Iterator<Item = &'a str>,
|
||||
{
|
||||
let mut package_path = module_search_path.to_path_buf();
|
||||
|
||||
// `true` if inside a folder that is a namespace package (has no `__init__.py`).
|
||||
// Namespace packages are special because they can be spread across multiple search paths.
|
||||
// https://peps.python.org/pep-0420/
|
||||
let mut in_namespace_package = false;
|
||||
|
||||
// `true` if resolving a sub-package. For example, `true` when resolving `bar` of `foo.bar`.
|
||||
let mut in_sub_package = false;
|
||||
|
||||
// For `foo.bar.baz`, test that `foo` and `baz` both contain a `__init__.py`.
|
||||
for folder in components {
|
||||
package_path.push(folder);
|
||||
|
||||
let has_init_py = fs.is_file(&package_path.join("__init__.py"))
|
||||
|| fs.is_file(&package_path.join("__init__.pyi"));
|
||||
|
||||
if has_init_py {
|
||||
in_namespace_package = false;
|
||||
} else if fs.is_directory(&package_path) {
|
||||
// A directory without an `__init__.py` is a namespace package, continue with the next folder.
|
||||
in_namespace_package = true;
|
||||
} else if in_namespace_package {
|
||||
// Package not found but it is part of a namespace package.
|
||||
return Err(PackageKind::Namespace);
|
||||
} else if in_sub_package {
|
||||
// A regular sub package wasn't found.
|
||||
return Err(PackageKind::Regular);
|
||||
} else {
|
||||
// We couldn't find `foo` for `foo.bar.baz`, search the next search path.
|
||||
return Err(PackageKind::Root);
|
||||
}
|
||||
|
||||
in_sub_package = true;
|
||||
}
|
||||
|
||||
let kind = if in_namespace_package {
|
||||
PackageKind::Namespace
|
||||
} else if in_sub_package {
|
||||
PackageKind::Regular
|
||||
} else {
|
||||
PackageKind::Root
|
||||
};
|
||||
|
||||
Ok(ResolvedPackage {
|
||||
kind,
|
||||
path: package_path,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ResolvedPackage {
|
||||
path: FileSystemPathBuf,
|
||||
kind: PackageKind,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
|
||||
enum PackageKind {
|
||||
/// A root package or module. E.g. `foo` in `foo.bar.baz` or just `foo`.
|
||||
Root,
|
||||
|
||||
/// A regular sub-package where the parent contains an `__init__.py`.
|
||||
///
|
||||
/// For example, `bar` in `foo.bar` when the `foo` directory contains an `__init__.py`.
|
||||
Regular,
|
||||
|
||||
/// A sub-package in a namespace package. A namespace package is a package without an `__init__.py`.
|
||||
///
|
||||
/// For example, `bar` in `foo.bar` if the `foo` directory contains no `__init__.py`.
|
||||
Namespace,
|
||||
}
|
||||
|
||||
impl PackageKind {
|
||||
const fn is_regular_package(self) -> bool {
|
||||
matches!(self, PackageKind::Regular)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use ruff_db::file_system::{FileSystemPath, FileSystemPathBuf};
|
||||
use ruff_db::vfs::{system_path_to_file, VfsFile, VfsPath};
|
||||
|
||||
use crate::db::tests::TestDb;
|
||||
use crate::module::{ModuleKind, ModuleName};
|
||||
|
||||
use super::{
|
||||
path_to_module, resolve_module, set_module_resolution_settings, ModuleResolutionSettings,
|
||||
TYPESHED_STDLIB_DIRECTORY,
|
||||
};
|
||||
|
||||
struct TestCase {
|
||||
db: TestDb,
|
||||
|
||||
src: FileSystemPathBuf,
|
||||
custom_typeshed: FileSystemPathBuf,
|
||||
site_packages: FileSystemPathBuf,
|
||||
}
|
||||
|
||||
fn create_resolver() -> std::io::Result<TestCase> {
|
||||
let mut db = TestDb::new();
|
||||
|
||||
let src = FileSystemPath::new("src").to_path_buf();
|
||||
let site_packages = FileSystemPath::new("site_packages").to_path_buf();
|
||||
let custom_typeshed = FileSystemPath::new("typeshed").to_path_buf();
|
||||
|
||||
let fs = db.memory_file_system();
|
||||
|
||||
fs.create_directory_all(&src)?;
|
||||
fs.create_directory_all(&site_packages)?;
|
||||
fs.create_directory_all(&custom_typeshed)?;
|
||||
|
||||
let settings = ModuleResolutionSettings {
|
||||
extra_paths: vec![],
|
||||
workspace_root: src.clone(),
|
||||
site_packages: Some(site_packages.clone()),
|
||||
custom_typeshed: Some(custom_typeshed.clone()),
|
||||
};
|
||||
|
||||
set_module_resolution_settings(&mut db, settings);
|
||||
|
||||
Ok(TestCase {
|
||||
db,
|
||||
src,
|
||||
custom_typeshed,
|
||||
site_packages,
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_party_module() -> anyhow::Result<()> {
|
||||
let TestCase { db, src, .. } = create_resolver()?;
|
||||
|
||||
let foo_module_name = ModuleName::new_static("foo").unwrap();
|
||||
let foo_path = src.join("foo.py");
|
||||
db.memory_file_system()
|
||||
.write_file(&foo_path, "print('Hello, world!')")?;
|
||||
|
||||
let foo_module = resolve_module(&db, foo_module_name.clone()).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(&foo_module),
|
||||
resolve_module(&db, foo_module_name.clone()).as_ref()
|
||||
);
|
||||
|
||||
assert_eq!("foo", foo_module.name());
|
||||
assert_eq!(&src, foo_module.search_path().path());
|
||||
assert_eq!(ModuleKind::Module, foo_module.kind());
|
||||
assert_eq!(&foo_path, foo_module.file().path(&db));
|
||||
|
||||
assert_eq!(
|
||||
Some(foo_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(foo_path))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stdlib() -> anyhow::Result<()> {
|
||||
let TestCase {
|
||||
db,
|
||||
custom_typeshed,
|
||||
..
|
||||
} = create_resolver()?;
|
||||
|
||||
let stdlib_dir = custom_typeshed.join(TYPESHED_STDLIB_DIRECTORY);
|
||||
let functools_path = stdlib_dir.join("functools.py");
|
||||
db.memory_file_system()
|
||||
.write_file(&functools_path, "def update_wrapper(): ...")?;
|
||||
|
||||
let functools_module_name = ModuleName::new_static("functools").unwrap();
|
||||
let functools_module = resolve_module(&db, functools_module_name.clone()).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(&functools_module),
|
||||
resolve_module(&db, functools_module_name).as_ref()
|
||||
);
|
||||
|
||||
assert_eq!(&stdlib_dir, functools_module.search_path().path());
|
||||
assert_eq!(ModuleKind::Module, functools_module.kind());
|
||||
assert_eq!(&functools_path.clone(), functools_module.file().path(&db));
|
||||
|
||||
assert_eq!(
|
||||
Some(functools_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(functools_path))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_party_precedence_over_stdlib() -> anyhow::Result<()> {
|
||||
let TestCase {
|
||||
db,
|
||||
src,
|
||||
custom_typeshed,
|
||||
..
|
||||
} = create_resolver()?;
|
||||
|
||||
let stdlib_dir = custom_typeshed.join(TYPESHED_STDLIB_DIRECTORY);
|
||||
let stdlib_functools_path = stdlib_dir.join("functools.py");
|
||||
let first_party_functools_path = src.join("functools.py");
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
(&stdlib_functools_path, "def update_wrapper(): ..."),
|
||||
(&first_party_functools_path, "def update_wrapper(): ..."),
|
||||
])?;
|
||||
|
||||
let functools_module_name = ModuleName::new_static("functools").unwrap();
|
||||
let functools_module = resolve_module(&db, functools_module_name.clone()).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(&functools_module),
|
||||
resolve_module(&db, functools_module_name).as_ref()
|
||||
);
|
||||
assert_eq!(&src, functools_module.search_path().path());
|
||||
assert_eq!(ModuleKind::Module, functools_module.kind());
|
||||
assert_eq!(
|
||||
&first_party_functools_path.clone(),
|
||||
functools_module.file().path(&db)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Some(functools_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(first_party_functools_path))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO: Port typeshed test case. Porting isn't possible at the moment because the vendored zip
|
||||
// is part of the red knot crate
|
||||
// #[test]
|
||||
// fn typeshed_zip_created_at_build_time() -> anyhow::Result<()> {
|
||||
// // The file path here is hardcoded in this crate's `build.rs` script.
|
||||
// // Luckily this crate will fail to build if this file isn't available at build time.
|
||||
// const TYPESHED_ZIP_BYTES: &[u8] =
|
||||
// include_bytes!(concat!(env!("OUT_DIR"), "/zipped_typeshed.zip"));
|
||||
// assert!(!TYPESHED_ZIP_BYTES.is_empty());
|
||||
// let mut typeshed_zip_archive = ZipArchive::new(Cursor::new(TYPESHED_ZIP_BYTES))?;
|
||||
//
|
||||
// let path_to_functools = Path::new("stdlib").join("functools.pyi");
|
||||
// let mut functools_module_stub = typeshed_zip_archive
|
||||
// .by_name(path_to_functools.to_str().unwrap())
|
||||
// .unwrap();
|
||||
// assert!(functools_module_stub.is_file());
|
||||
//
|
||||
// let mut functools_module_stub_source = String::new();
|
||||
// functools_module_stub.read_to_string(&mut functools_module_stub_source)?;
|
||||
//
|
||||
// assert!(functools_module_stub_source.contains("def update_wrapper("));
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
#[test]
|
||||
fn resolve_package() -> anyhow::Result<()> {
|
||||
let TestCase { src, db, .. } = create_resolver()?;
|
||||
|
||||
let foo_dir = src.join("foo");
|
||||
let foo_path = foo_dir.join("__init__.py");
|
||||
|
||||
db.memory_file_system()
|
||||
.write_file(&foo_path, "print('Hello, world!')")?;
|
||||
|
||||
let foo_module = resolve_module(&db, ModuleName::new_static("foo").unwrap()).unwrap();
|
||||
|
||||
assert_eq!("foo", foo_module.name());
|
||||
assert_eq!(&src, foo_module.search_path().path());
|
||||
assert_eq!(&foo_path, foo_module.file().path(&db));
|
||||
|
||||
assert_eq!(
|
||||
Some(&foo_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(foo_path)).as_ref()
|
||||
);
|
||||
|
||||
// Resolving by directory doesn't resolve to the init file.
|
||||
assert_eq!(None, path_to_module(&db, &VfsPath::FileSystem(foo_dir)));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn package_priority_over_module() -> anyhow::Result<()> {
|
||||
let TestCase { db, src, .. } = create_resolver()?;
|
||||
|
||||
let foo_dir = src.join("foo");
|
||||
let foo_init = foo_dir.join("__init__.py");
|
||||
|
||||
db.memory_file_system()
|
||||
.write_file(&foo_init, "print('Hello, world!')")?;
|
||||
|
||||
let foo_py = src.join("foo.py");
|
||||
db.memory_file_system()
|
||||
.write_file(&foo_py, "print('Hello, world!')")?;
|
||||
|
||||
let foo_module = resolve_module(&db, ModuleName::new_static("foo").unwrap()).unwrap();
|
||||
|
||||
assert_eq!(&src, foo_module.search_path().path());
|
||||
assert_eq!(&foo_init, foo_module.file().path(&db));
|
||||
assert_eq!(ModuleKind::Package, foo_module.kind());
|
||||
|
||||
assert_eq!(
|
||||
Some(foo_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(foo_init))
|
||||
);
|
||||
assert_eq!(None, path_to_module(&db, &VfsPath::FileSystem(foo_py)));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn typing_stub_over_module() -> anyhow::Result<()> {
|
||||
let TestCase { db, src, .. } = create_resolver()?;
|
||||
|
||||
let foo_stub = src.join("foo.pyi");
|
||||
let foo_py = src.join("foo.py");
|
||||
db.memory_file_system()
|
||||
.write_files([(&foo_stub, "x: int"), (&foo_py, "print('Hello, world!')")])?;
|
||||
|
||||
let foo = resolve_module(&db, ModuleName::new_static("foo").unwrap()).unwrap();
|
||||
|
||||
assert_eq!(&src, foo.search_path().path());
|
||||
assert_eq!(&foo_stub, foo.file().path(&db));
|
||||
|
||||
assert_eq!(
|
||||
Some(foo),
|
||||
path_to_module(&db, &VfsPath::FileSystem(foo_stub))
|
||||
);
|
||||
assert_eq!(None, path_to_module(&db, &VfsPath::FileSystem(foo_py)));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_packages() -> anyhow::Result<()> {
|
||||
let TestCase { db, src, .. } = create_resolver()?;
|
||||
|
||||
let foo = src.join("foo");
|
||||
let bar = foo.join("bar");
|
||||
let baz = bar.join("baz.py");
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
(&foo.join("__init__.py"), ""),
|
||||
(&bar.join("__init__.py"), ""),
|
||||
(&baz, "print('Hello, world!')"),
|
||||
])?;
|
||||
|
||||
let baz_module =
|
||||
resolve_module(&db, ModuleName::new_static("foo.bar.baz").unwrap()).unwrap();
|
||||
|
||||
assert_eq!(&src, baz_module.search_path().path());
|
||||
assert_eq!(&baz, baz_module.file().path(&db));
|
||||
|
||||
assert_eq!(
|
||||
Some(baz_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(baz))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn namespace_package() -> anyhow::Result<()> {
|
||||
let TestCase {
|
||||
db,
|
||||
src,
|
||||
site_packages,
|
||||
..
|
||||
} = create_resolver()?;
|
||||
|
||||
// From [PEP420](https://peps.python.org/pep-0420/#nested-namespace-packages).
|
||||
// But uses `src` for `project1` and `site_packages2` for `project2`.
|
||||
// ```
|
||||
// src
|
||||
// parent
|
||||
// child
|
||||
// one.py
|
||||
// site_packages
|
||||
// parent
|
||||
// child
|
||||
// two.py
|
||||
// ```
|
||||
|
||||
let parent1 = src.join("parent");
|
||||
let child1 = parent1.join("child");
|
||||
let one = child1.join("one.py");
|
||||
|
||||
let parent2 = site_packages.join("parent");
|
||||
let child2 = parent2.join("child");
|
||||
let two = child2.join("two.py");
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
(&one, "print('Hello, world!')"),
|
||||
(&two, "print('Hello, world!')"),
|
||||
])?;
|
||||
|
||||
let one_module =
|
||||
resolve_module(&db, ModuleName::new_static("parent.child.one").unwrap()).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(one_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(one))
|
||||
);
|
||||
|
||||
let two_module =
|
||||
resolve_module(&db, ModuleName::new_static("parent.child.two").unwrap()).unwrap();
|
||||
assert_eq!(
|
||||
Some(two_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(two))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn regular_package_in_namespace_package() -> anyhow::Result<()> {
|
||||
let TestCase {
|
||||
db,
|
||||
src,
|
||||
site_packages,
|
||||
..
|
||||
} = create_resolver()?;
|
||||
|
||||
// Adopted test case from the [PEP420 examples](https://peps.python.org/pep-0420/#nested-namespace-packages).
|
||||
// The `src/parent/child` package is a regular package. Therefore, `site_packages/parent/child/two.py` should not be resolved.
|
||||
// ```
|
||||
// src
|
||||
// parent
|
||||
// child
|
||||
// one.py
|
||||
// site_packages
|
||||
// parent
|
||||
// child
|
||||
// two.py
|
||||
// ```
|
||||
|
||||
let parent1 = src.join("parent");
|
||||
let child1 = parent1.join("child");
|
||||
let one = child1.join("one.py");
|
||||
|
||||
let parent2 = site_packages.join("parent");
|
||||
let child2 = parent2.join("child");
|
||||
let two = child2.join("two.py");
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
(&child1.join("__init__.py"), "print('Hello, world!')"),
|
||||
(&one, "print('Hello, world!')"),
|
||||
(&two, "print('Hello, world!')"),
|
||||
])?;
|
||||
|
||||
let one_module =
|
||||
resolve_module(&db, ModuleName::new_static("parent.child.one").unwrap()).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(one_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(one))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
None,
|
||||
resolve_module(&db, ModuleName::new_static("parent.child.two").unwrap())
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn module_search_path_priority() -> anyhow::Result<()> {
|
||||
let TestCase {
|
||||
db,
|
||||
src,
|
||||
site_packages,
|
||||
..
|
||||
} = create_resolver()?;
|
||||
|
||||
let foo_src = src.join("foo.py");
|
||||
let foo_site_packages = site_packages.join("foo.py");
|
||||
|
||||
db.memory_file_system()
|
||||
.write_files([(&foo_src, ""), (&foo_site_packages, "")])?;
|
||||
|
||||
let foo_module = resolve_module(&db, ModuleName::new_static("foo").unwrap()).unwrap();
|
||||
|
||||
assert_eq!(&src, foo_module.search_path().path());
|
||||
assert_eq!(&foo_src, foo_module.file().path(&db));
|
||||
|
||||
assert_eq!(
|
||||
Some(foo_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(foo_src))
|
||||
);
|
||||
assert_eq!(
|
||||
None,
|
||||
path_to_module(&db, &VfsPath::FileSystem(foo_site_packages))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(target_family = "unix")]
|
||||
fn symlink() -> anyhow::Result<()> {
|
||||
let TestCase {
|
||||
mut db,
|
||||
src,
|
||||
site_packages,
|
||||
custom_typeshed,
|
||||
} = create_resolver()?;
|
||||
|
||||
db.with_os_file_system();
|
||||
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let root = FileSystemPath::from_std_path(temp_dir.path()).unwrap();
|
||||
|
||||
let src = root.join(src);
|
||||
let site_packages = root.join(site_packages);
|
||||
let custom_typeshed = root.join(custom_typeshed);
|
||||
|
||||
let foo = src.join("foo.py");
|
||||
let bar = src.join("bar.py");
|
||||
|
||||
std::fs::create_dir_all(src.as_std_path())?;
|
||||
std::fs::create_dir_all(site_packages.as_std_path())?;
|
||||
std::fs::create_dir_all(custom_typeshed.as_std_path())?;
|
||||
|
||||
std::fs::write(foo.as_std_path(), "")?;
|
||||
std::os::unix::fs::symlink(foo.as_std_path(), bar.as_std_path())?;
|
||||
|
||||
let settings = ModuleResolutionSettings {
|
||||
extra_paths: vec![],
|
||||
workspace_root: src.clone(),
|
||||
site_packages: Some(site_packages),
|
||||
custom_typeshed: Some(custom_typeshed),
|
||||
};
|
||||
|
||||
set_module_resolution_settings(&mut db, settings);
|
||||
|
||||
let foo_module = resolve_module(&db, ModuleName::new_static("foo").unwrap()).unwrap();
|
||||
let bar_module = resolve_module(&db, ModuleName::new_static("bar").unwrap()).unwrap();
|
||||
|
||||
assert_ne!(foo_module, bar_module);
|
||||
|
||||
assert_eq!(&src, foo_module.search_path().path());
|
||||
assert_eq!(&foo, foo_module.file().path(&db));
|
||||
|
||||
// `foo` and `bar` shouldn't resolve to the same file
|
||||
|
||||
assert_eq!(&src, bar_module.search_path().path());
|
||||
assert_eq!(&bar, bar_module.file().path(&db));
|
||||
assert_eq!(&foo, foo_module.file().path(&db));
|
||||
|
||||
assert_ne!(&foo_module, &bar_module);
|
||||
|
||||
assert_eq!(
|
||||
Some(foo_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(foo))
|
||||
);
|
||||
assert_eq!(
|
||||
Some(bar_module),
|
||||
path_to_module(&db, &VfsPath::FileSystem(bar))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deleting_an_unrelated_file_doesnt_change_module_resolution() -> anyhow::Result<()> {
|
||||
let TestCase { mut db, src, .. } = create_resolver()?;
|
||||
|
||||
let foo_path = src.join("foo.py");
|
||||
let bar_path = src.join("bar.py");
|
||||
|
||||
db.memory_file_system()
|
||||
.write_files([(&foo_path, "x = 1"), (&bar_path, "y = 2")])?;
|
||||
|
||||
let foo_module_name = ModuleName::new_static("foo").unwrap();
|
||||
let foo_module = resolve_module(&db, foo_module_name.clone()).unwrap();
|
||||
|
||||
let bar = system_path_to_file(&db, &bar_path).expect("bar.py to exist");
|
||||
|
||||
db.clear_salsa_events();
|
||||
|
||||
// Delete `bar.py`
|
||||
db.memory_file_system().remove_file(&bar_path)?;
|
||||
bar.touch(&mut db);
|
||||
|
||||
// Re-query the foo module. The foo module should still be cached because `bar.py` isn't relevant
|
||||
// for resolving `foo`.
|
||||
|
||||
let foo_module2 = resolve_module(&db, foo_module_name);
|
||||
|
||||
assert!(!db
|
||||
.take_salsa_events()
|
||||
.iter()
|
||||
.any(|event| { matches!(event.kind, salsa::EventKind::WillExecute { .. }) }));
|
||||
|
||||
assert_eq!(Some(foo_module), foo_module2);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adding_a_file_that_the_module_resolution_depends_on_invalidates_the_query(
|
||||
) -> anyhow::Result<()> {
|
||||
let TestCase { mut db, src, .. } = create_resolver()?;
|
||||
let foo_path = src.join("foo.py");
|
||||
|
||||
let foo_module_name = ModuleName::new_static("foo").unwrap();
|
||||
assert_eq!(resolve_module(&db, foo_module_name.clone()), None);
|
||||
|
||||
// Now write the foo file
|
||||
db.memory_file_system().write_file(&foo_path, "x = 1")?;
|
||||
VfsFile::touch_path(&mut db, &VfsPath::FileSystem(foo_path.clone()));
|
||||
let foo_file = system_path_to_file(&db, &foo_path).expect("foo.py to exist");
|
||||
|
||||
let foo_module = resolve_module(&db, foo_module_name).expect("Foo module to resolve");
|
||||
assert_eq!(foo_file, foo_module.file());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn removing_a_file_that_the_module_resolution_depends_on_invalidates_the_query(
|
||||
) -> anyhow::Result<()> {
|
||||
let TestCase { mut db, src, .. } = create_resolver()?;
|
||||
let foo_path = src.join("foo.py");
|
||||
let foo_init_path = src.join("foo/__init__.py");
|
||||
|
||||
db.memory_file_system()
|
||||
.write_files([(&foo_path, "x = 1"), (&foo_init_path, "x = 2")])?;
|
||||
|
||||
let foo_module_name = ModuleName::new_static("foo").unwrap();
|
||||
let foo_module = resolve_module(&db, foo_module_name.clone()).expect("foo module to exist");
|
||||
|
||||
assert_eq!(&foo_init_path, foo_module.file().path(&db));
|
||||
|
||||
// Delete `foo/__init__.py` and the `foo` folder. `foo` should now resolve to `foo.py`
|
||||
db.memory_file_system().remove_file(&foo_init_path)?;
|
||||
db.memory_file_system()
|
||||
.remove_directory(foo_init_path.parent().unwrap())?;
|
||||
VfsFile::touch_path(&mut db, &VfsPath::FileSystem(foo_init_path.clone()));
|
||||
|
||||
let foo_module = resolve_module(&db, foo_module_name).expect("Foo module to resolve");
|
||||
assert_eq!(&foo_path, foo_module.file().path(&db));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
56
crates/red_knot_python_semantic/src/name.rs
Normal file
@ -0,0 +1,56 @@
use std::ops::Deref;

#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Name(smol_str::SmolStr);

impl Name {
    #[inline]
    pub fn new(name: &str) -> Self {
        Self(smol_str::SmolStr::new(name))
    }

    #[inline]
    pub fn new_static(name: &'static str) -> Self {
        Self(smol_str::SmolStr::new_static(name))
    }

    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl Deref for Name {
    type Target = str;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.as_str()
    }
}

impl<T> From<T> for Name
where
    T: Into<smol_str::SmolStr>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

impl std::fmt::Display for Name {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl PartialEq<str> for Name {
    fn eq(&self, other: &str) -> bool {
        self.as_str() == other
    }
}

impl PartialEq<Name> for str {
    fn eq(&self, other: &Name) -> bool {
        other == self
    }
}
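// Illustrative usage sketch (editor's addition, not part of this commit): `Name` wraps a
// `SmolStr` and behaves like a string thanks to `Deref<Target = str>` and `PartialEq<str>`.
//
//     let name = Name::new_static("foo");
//     assert_eq!(name, *"foo");
//     assert_eq!(name.as_str().len(), 3);
//     assert_eq!(Name::from("bar").to_string(), "bar");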
|
24
crates/red_knot_python_semantic/src/node_key.rs
Normal file
@ -0,0 +1,24 @@
use ruff_python_ast::{AnyNodeRef, NodeKind};
use ruff_text_size::{Ranged, TextRange};

/// Compact key for a node for use in a hash map.
///
/// Compares two nodes by their kind and text range.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub(super) struct NodeKey {
    kind: NodeKind,
    range: TextRange,
}

impl NodeKey {
    pub(super) fn from_node<'a, N>(node: N) -> Self
    where
        N: Into<AnyNodeRef<'a>>,
    {
        let node = node.into();
        NodeKey {
            kind: node.kind(),
            range: node.range(),
        }
    }
}
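// Illustrative sketch (editor's addition, not part of this commit): a `NodeKey` is a cheap,
// copyable stand-in for an AST node, so it can key side tables without holding the AST.
// `expr` below is an assumed `&ast::Expr` borrowed from a parsed module:
//
//     use rustc_hash::FxHashMap;
//     let mut expression_scopes: FxHashMap<NodeKey, FileScopeId> = FxHashMap::default();
//     expression_scopes.insert(NodeKey::from_node(expr), FileScopeId::root());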
|
668
crates/red_knot_python_semantic/src/semantic_index.rs
Normal file
@ -0,0 +1,668 @@
use std::iter::FusedIterator;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use ruff_db::parsed::parsed_module;
|
||||
use ruff_db::vfs::VfsFile;
|
||||
use ruff_index::{IndexSlice, IndexVec};
|
||||
use ruff_python_ast as ast;
|
||||
|
||||
use crate::node_key::NodeKey;
|
||||
use crate::semantic_index::ast_ids::{AstId, AstIds, ScopeClassId, ScopeFunctionId};
|
||||
use crate::semantic_index::builder::SemanticIndexBuilder;
|
||||
use crate::semantic_index::symbol::{
|
||||
FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable,
|
||||
};
|
||||
use crate::Db;
|
||||
|
||||
pub mod ast_ids;
|
||||
mod builder;
|
||||
pub mod definition;
|
||||
pub mod symbol;
|
||||
|
||||
type SymbolMap = hashbrown::HashMap<ScopedSymbolId, (), ()>;
|
||||
|
||||
/// Returns the semantic index for `file`.
|
||||
///
|
||||
/// Prefer using [`symbol_table`] when working with symbols from a single scope.
|
||||
#[salsa::tracked(return_ref, no_eq)]
|
||||
pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex {
|
||||
let parsed = parsed_module(db.upcast(), file);
|
||||
|
||||
SemanticIndexBuilder::new(parsed).build()
|
||||
}
|
||||
|
||||
/// Returns the symbol table for a specific `scope`.
|
||||
///
|
||||
/// Using [`symbol_table`] over [`semantic_index`] has the advantage that
|
||||
/// Salsa can avoid invalidating dependent queries if this scope's symbol table
|
||||
/// is unchanged.
|
||||
#[salsa::tracked]
|
||||
pub(crate) fn symbol_table(db: &dyn Db, scope: ScopeId) -> Arc<SymbolTable> {
|
||||
let index = semantic_index(db, scope.file(db));
|
||||
|
||||
index.symbol_table(scope.file_scope_id(db))
|
||||
}
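// Usage sketch (editor's addition, not part of this commit): querying at scope granularity
// keeps dependent Salsa queries keyed on a single scope's symbol table rather than on the
// whole-file index. `db` and `file` are assumed to be in scope, as in the tests below.
//
//     let scope = root_scope(db, file);
//     let table = symbol_table(db, scope);   // Arc<SymbolTable>, invalidated per scope
//     let index = semantic_index(db, file);  // &SemanticIndex, recomputed per file change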
|
||||
|
||||
/// Returns the root scope of `file`.
|
||||
#[salsa::tracked]
|
||||
pub(crate) fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId {
|
||||
FileScopeId::root().to_scope_id(db, file)
|
||||
}
|
||||
|
||||
/// Returns the symbol with the given name in `file`'s public scope or `None` if
|
||||
/// no symbol with the given name exists.
|
||||
pub fn public_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option<PublicSymbolId> {
|
||||
let root_scope = root_scope(db, file);
|
||||
let symbol_table = symbol_table(db, root_scope);
|
||||
let local = symbol_table.symbol_id_by_name(name)?;
|
||||
Some(local.to_public_symbol(db, file))
|
||||
}
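// Usage sketch (editor's addition, not part of this commit); `db`, `file`, and the symbol
// name "foo" are assumptions for illustration:
//
//     if let Some(symbol) = public_symbol(db, file, "foo") {
//         // `symbol` is a `PublicSymbolId` for `foo` in `file`'s module-level scope.
//     }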
|
||||
|
||||
/// The symbol tables for an entire file.
|
||||
#[derive(Debug)]
|
||||
pub struct SemanticIndex {
|
||||
/// List of all symbol tables in this file, indexed by scope.
|
||||
symbol_tables: IndexVec<FileScopeId, Arc<SymbolTable>>,
|
||||
|
||||
/// List of all scopes in this file.
|
||||
scopes: IndexVec<FileScopeId, Scope>,
|
||||
|
||||
/// Maps expressions to their corresponding scope.
|
||||
/// We can't use [`ExpressionId`] here, because the challenge is how to get from
|
||||
/// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope).
|
||||
expression_scopes: FxHashMap<NodeKey, FileScopeId>,
|
||||
|
||||
/// Lookup table to map between node ids and ast nodes.
|
||||
///
|
||||
/// Note: We should not depend on this map when analysing other files; otherwise,
/// changing a file invalidates all of its dependents.
|
||||
ast_ids: IndexVec<FileScopeId, AstIds>,
|
||||
|
||||
/// Map from scope to the node that introduces the scope.
|
||||
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>,
|
||||
}
|
||||
|
||||
impl SemanticIndex {
|
||||
/// Returns the symbol table for a specific scope.
|
||||
///
|
||||
/// Use the Salsa cached [`symbol_table`] query if you only need the
|
||||
/// symbol table for a single scope.
|
||||
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
|
||||
self.symbol_tables[scope_id].clone()
|
||||
}
|
||||
|
||||
pub(crate) fn ast_ids(&self, scope_id: FileScopeId) -> &AstIds {
|
||||
&self.ast_ids[scope_id]
|
||||
}
|
||||
|
||||
/// Returns the ID of the `expression`'s enclosing scope.
|
||||
#[allow(unused)]
|
||||
pub(crate) fn expression_scope_id(&self, expression: &ast::Expr) -> FileScopeId {
|
||||
self.expression_scopes[&NodeKey::from_node(expression)]
|
||||
}
|
||||
|
||||
/// Returns the [`Scope`] of the `expression`'s enclosing scope.
|
||||
#[allow(unused)]
|
||||
pub(crate) fn expression_scope(&self, expression: &ast::Expr) -> &Scope {
|
||||
&self.scopes[self.expression_scope_id(expression)]
|
||||
}
|
||||
|
||||
/// Returns the [`Scope`] with the given id.
|
||||
#[allow(unused)]
|
||||
pub(crate) fn scope(&self, id: FileScopeId) -> &Scope {
|
||||
&self.scopes[id]
|
||||
}
|
||||
|
||||
/// Returns the id of the parent scope.
|
||||
pub(crate) fn parent_scope_id(&self, scope_id: FileScopeId) -> Option<FileScopeId> {
|
||||
let scope = self.scope(scope_id);
|
||||
scope.parent
|
||||
}
|
||||
|
||||
/// Returns the parent scope of `scope_id`.
|
||||
#[allow(unused)]
|
||||
pub(crate) fn parent_scope(&self, scope_id: FileScopeId) -> Option<&Scope> {
|
||||
Some(&self.scopes[self.parent_scope_id(scope_id)?])
|
||||
}
|
||||
|
||||
/// Returns an iterator over the descendent scopes of `scope`.
|
||||
#[allow(unused)]
|
||||
pub(crate) fn descendent_scopes(&self, scope: FileScopeId) -> DescendentsIter {
|
||||
DescendentsIter::new(self, scope)
|
||||
}
|
||||
|
||||
/// Returns an iterator over the direct child scopes of `scope`.
|
||||
#[allow(unused)]
|
||||
pub(crate) fn child_scopes(&self, scope: FileScopeId) -> ChildrenIter {
|
||||
ChildrenIter::new(self, scope)
|
||||
}
|
||||
|
||||
/// Returns an iterator over all ancestors of `scope`, starting with `scope` itself.
|
||||
#[allow(unused)]
|
||||
pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter {
|
||||
AncestorsIter::new(self, scope)
|
||||
}
|
||||
|
||||
pub(crate) fn scope_node(&self, scope_id: FileScopeId) -> NodeWithScopeId {
|
||||
self.scope_nodes[scope_id]
|
||||
}
|
||||
}
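// Illustrative sketch (editor's addition, not part of this commit): walking the scope tree
// with the iterator helpers, mirroring the `scope_iterators` test at the end of this file.
//
//     let index = semantic_index(db, file);
//     for (scope_id, scope) in index.child_scopes(FileScopeId::root()) {
//         let table = index.symbol_table(scope_id);
//         println!("{} has {} symbols", scope.name(), table.symbols().count());
//     }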
|
||||
|
||||
/// ID that uniquely identifies an expression inside a [`Scope`].
|
||||
|
||||
pub struct AncestorsIter<'a> {
|
||||
scopes: &'a IndexSlice<FileScopeId, Scope>,
|
||||
next_id: Option<FileScopeId>,
|
||||
}
|
||||
|
||||
impl<'a> AncestorsIter<'a> {
|
||||
fn new(module_symbol_table: &'a SemanticIndex, start: FileScopeId) -> Self {
|
||||
Self {
|
||||
scopes: &module_symbol_table.scopes,
|
||||
next_id: Some(start),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for AncestorsIter<'a> {
|
||||
type Item = (FileScopeId, &'a Scope);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let current_id = self.next_id?;
|
||||
let current = &self.scopes[current_id];
|
||||
self.next_id = current.parent;
|
||||
|
||||
Some((current_id, current))
|
||||
}
|
||||
}
|
||||
|
||||
impl FusedIterator for AncestorsIter<'_> {}
|
||||
|
||||
pub struct DescendentsIter<'a> {
|
||||
next_id: FileScopeId,
|
||||
descendents: std::slice::Iter<'a, Scope>,
|
||||
}
|
||||
|
||||
impl<'a> DescendentsIter<'a> {
|
||||
fn new(symbol_table: &'a SemanticIndex, scope_id: FileScopeId) -> Self {
|
||||
let scope = &symbol_table.scopes[scope_id];
|
||||
let scopes = &symbol_table.scopes[scope.descendents.clone()];
|
||||
|
||||
Self {
|
||||
next_id: scope_id + 1,
|
||||
descendents: scopes.iter(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for DescendentsIter<'a> {
|
||||
type Item = (FileScopeId, &'a Scope);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let descendent = self.descendents.next()?;
|
||||
let id = self.next_id;
|
||||
self.next_id = self.next_id + 1;
|
||||
|
||||
Some((id, descendent))
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
self.descendents.size_hint()
|
||||
}
|
||||
}
|
||||
|
||||
impl FusedIterator for DescendentsIter<'_> {}
|
||||
|
||||
impl ExactSizeIterator for DescendentsIter<'_> {}
|
||||
|
||||
pub struct ChildrenIter<'a> {
|
||||
parent: FileScopeId,
|
||||
descendents: DescendentsIter<'a>,
|
||||
}
|
||||
|
||||
impl<'a> ChildrenIter<'a> {
|
||||
fn new(module_symbol_table: &'a SemanticIndex, parent: FileScopeId) -> Self {
|
||||
let descendents = DescendentsIter::new(module_symbol_table, parent);
|
||||
|
||||
Self {
|
||||
parent,
|
||||
descendents,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for ChildrenIter<'a> {
|
||||
type Item = (FileScopeId, &'a Scope);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.descendents
|
||||
.find(|(_, scope)| scope.parent == Some(self.parent))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub(crate) enum NodeWithScopeId {
|
||||
Module,
|
||||
Class(AstId<ScopeClassId>),
|
||||
ClassTypeParams(AstId<ScopeClassId>),
|
||||
Function(AstId<ScopeFunctionId>),
|
||||
FunctionTypeParams(AstId<ScopeFunctionId>),
|
||||
}
|
||||
|
||||
impl NodeWithScopeId {
|
||||
fn scope_kind(self) -> ScopeKind {
|
||||
match self {
|
||||
NodeWithScopeId::Module => ScopeKind::Module,
|
||||
NodeWithScopeId::Class(_) => ScopeKind::Class,
|
||||
NodeWithScopeId::Function(_) => ScopeKind::Function,
|
||||
NodeWithScopeId::ClassTypeParams(_) | NodeWithScopeId::FunctionTypeParams(_) => {
|
||||
ScopeKind::Annotation
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FusedIterator for ChildrenIter<'_> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use ruff_db::parsed::parsed_module;
|
||||
use ruff_db::vfs::{system_path_to_file, VfsFile};
|
||||
|
||||
use crate::db::tests::TestDb;
|
||||
use crate::semantic_index::symbol::{FileScopeId, ScopeKind, SymbolTable};
|
||||
use crate::semantic_index::{root_scope, semantic_index, symbol_table};
|
||||
|
||||
struct TestCase {
|
||||
db: TestDb,
|
||||
file: VfsFile,
|
||||
}
|
||||
|
||||
fn test_case(content: impl ToString) -> TestCase {
|
||||
let db = TestDb::new();
|
||||
db.memory_file_system()
|
||||
.write_file("test.py", content)
|
||||
.unwrap();
|
||||
|
||||
let file = system_path_to_file(&db, "test.py").unwrap();
|
||||
|
||||
TestCase { db, file }
|
||||
}
|
||||
|
||||
fn names(table: &SymbolTable) -> Vec<&str> {
|
||||
table
|
||||
.symbols()
|
||||
.map(|symbol| symbol.name().as_str())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty() {
|
||||
let TestCase { db, file } = test_case("");
|
||||
let root_table = symbol_table(&db, root_scope(&db, file));
|
||||
|
||||
assert_eq!(names(&root_table), Vec::<&str>::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn simple() {
|
||||
let TestCase { db, file } = test_case("x");
|
||||
let root_table = symbol_table(&db, root_scope(&db, file));
|
||||
|
||||
assert_eq!(names(&root_table), vec!["x"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn annotation_only() {
|
||||
let TestCase { db, file } = test_case("x: int");
|
||||
let root_table = symbol_table(&db, root_scope(&db, file));
|
||||
|
||||
assert_eq!(names(&root_table), vec!["int", "x"]);
|
||||
// TODO record definition
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn import() {
|
||||
let TestCase { db, file } = test_case("import foo");
|
||||
let root_table = symbol_table(&db, root_scope(&db, file));
|
||||
|
||||
assert_eq!(names(&root_table), vec!["foo"]);
|
||||
let foo = root_table.symbol_by_name("foo").unwrap();
|
||||
|
||||
assert_eq!(foo.definitions().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn import_sub() {
|
||||
let TestCase { db, file } = test_case("import foo.bar");
|
||||
let root_table = symbol_table(&db, root_scope(&db, file));
|
||||
|
||||
assert_eq!(names(&root_table), vec!["foo"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn import_as() {
|
||||
let TestCase { db, file } = test_case("import foo.bar as baz");
|
||||
let root_table = symbol_table(&db, root_scope(&db, file));
|
||||
|
||||
assert_eq!(names(&root_table), vec!["baz"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn import_from() {
|
||||
let TestCase { db, file } = test_case("from bar import foo");
|
||||
let root_table = symbol_table(&db, root_scope(&db, file));
|
||||
|
||||
assert_eq!(names(&root_table), vec!["foo"]);
|
||||
assert_eq!(
|
||||
root_table
|
||||
.symbol_by_name("foo")
|
||||
.unwrap()
|
||||
.definitions()
|
||||
.len(),
|
||||
1
|
||||
);
|
||||
assert!(
|
||||
root_table
|
||||
.symbol_by_name("foo")
|
||||
.is_some_and(|symbol| { symbol.is_defined() || !symbol.is_used() }),
|
||||
"symbols that are defined get the defined flag"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assign() {
|
||||
let TestCase { db, file } = test_case("x = foo");
|
||||
let root_table = symbol_table(&db, root_scope(&db, file));
|
||||
|
||||
assert_eq!(names(&root_table), vec!["foo", "x"]);
|
||||
assert_eq!(
|
||||
root_table.symbol_by_name("x").unwrap().definitions().len(),
|
||||
1
|
||||
);
|
||||
assert!(
|
||||
root_table
|
||||
.symbol_by_name("foo")
|
||||
.is_some_and(|symbol| { !symbol.is_defined() && symbol.is_used() }),
|
||||
"a symbol used but not defined in a scope should have only the used flag"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn class_scope() {
|
||||
let TestCase { db, file } = test_case(
|
||||
"
|
||||
class C:
|
||||
x = 1
|
||||
y = 2
|
||||
",
|
||||
);
|
||||
let root_table = symbol_table(&db, root_scope(&db, file));
|
||||
|
||||
assert_eq!(names(&root_table), vec!["C", "y"]);
|
||||
|
||||
let index = semantic_index(&db, file);
|
||||
|
||||
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
|
||||
assert_eq!(scopes.len(), 1);
|
||||
|
||||
let (class_scope_id, class_scope) = scopes[0];
|
||||
assert_eq!(class_scope.kind(), ScopeKind::Class);
|
||||
assert_eq!(class_scope.name(), "C");
|
||||
|
||||
let class_table = index.symbol_table(class_scope_id);
|
||||
assert_eq!(names(&class_table), vec!["x"]);
|
||||
assert_eq!(
|
||||
class_table.symbol_by_name("x").unwrap().definitions().len(),
|
||||
1
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn function_scope() {
|
||||
let TestCase { db, file } = test_case(
|
||||
"
|
||||
def func():
|
||||
x = 1
|
||||
y = 2
|
||||
",
|
||||
);
|
||||
let index = semantic_index(&db, file);
|
||||
let root_table = index.symbol_table(FileScopeId::root());
|
||||
|
||||
assert_eq!(names(&root_table), vec!["func", "y"]);
|
||||
|
||||
let scopes = index.child_scopes(FileScopeId::root()).collect::<Vec<_>>();
|
||||
assert_eq!(scopes.len(), 1);
|
||||
|
||||
let (function_scope_id, function_scope) = scopes[0];
|
||||
assert_eq!(function_scope.kind(), ScopeKind::Function);
|
||||
assert_eq!(function_scope.name(), "func");
|
||||
|
||||
let function_table = index.symbol_table(function_scope_id);
|
||||
assert_eq!(names(&function_table), vec!["x"]);
|
||||
assert_eq!(
|
||||
function_table
|
||||
.symbol_by_name("x")
|
||||
.unwrap()
|
||||
.definitions()
|
||||
.len(),
|
||||
1
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dupes() {
|
||||
let TestCase { db, file } = test_case(
|
||||
"
|
||||
def func():
|
||||
x = 1
|
||||
def func():
|
||||
y = 2
|
||||
",
|
||||
);
|
||||
let index = semantic_index(&db, file);
|
||||
let root_table = index.symbol_table(FileScopeId::root());
|
||||
|
||||
assert_eq!(names(&root_table), vec!["func"]);
|
||||
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
|
||||
assert_eq!(scopes.len(), 2);
|
||||
|
||||
let (func_scope1_id, func_scope_1) = scopes[0];
|
||||
let (func_scope2_id, func_scope_2) = scopes[1];
|
||||
|
||||
assert_eq!(func_scope_1.kind(), ScopeKind::Function);
|
||||
assert_eq!(func_scope_1.name(), "func");
|
||||
assert_eq!(func_scope_2.kind(), ScopeKind::Function);
|
||||
assert_eq!(func_scope_2.name(), "func");
|
||||
|
||||
let func1_table = index.symbol_table(func_scope1_id);
|
||||
let func2_table = index.symbol_table(func_scope2_id);
|
||||
assert_eq!(names(&func1_table), vec!["x"]);
|
||||
assert_eq!(names(&func2_table), vec!["y"]);
|
||||
assert_eq!(
|
||||
root_table
|
||||
.symbol_by_name("func")
|
||||
.unwrap()
|
||||
.definitions()
|
||||
.len(),
|
||||
2
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_function() {
|
||||
let TestCase { db, file } = test_case(
|
||||
"
|
||||
def func[T]():
|
||||
x = 1
|
||||
",
|
||||
);
|
||||
|
||||
let index = semantic_index(&db, file);
|
||||
let root_table = index.symbol_table(FileScopeId::root());
|
||||
|
||||
assert_eq!(names(&root_table), vec!["func"]);
|
||||
|
||||
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
|
||||
assert_eq!(scopes.len(), 1);
|
||||
let (ann_scope_id, ann_scope) = scopes[0];
|
||||
|
||||
assert_eq!(ann_scope.kind(), ScopeKind::Annotation);
|
||||
assert_eq!(ann_scope.name(), "func");
|
||||
let ann_table = index.symbol_table(ann_scope_id);
|
||||
assert_eq!(names(&ann_table), vec!["T"]);
|
||||
|
||||
let scopes: Vec<_> = index.child_scopes(ann_scope_id).collect();
|
||||
assert_eq!(scopes.len(), 1);
|
||||
let (func_scope_id, func_scope) = scopes[0];
|
||||
assert_eq!(func_scope.kind(), ScopeKind::Function);
|
||||
assert_eq!(func_scope.name(), "func");
|
||||
let func_table = index.symbol_table(func_scope_id);
|
||||
assert_eq!(names(&func_table), vec!["x"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_class() {
|
||||
let TestCase { db, file } = test_case(
|
||||
"
|
||||
class C[T]:
|
||||
x = 1
|
||||
",
|
||||
);
|
||||
|
||||
let index = semantic_index(&db, file);
|
||||
let root_table = index.symbol_table(FileScopeId::root());
|
||||
|
||||
assert_eq!(names(&root_table), vec!["C"]);
|
||||
|
||||
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
|
||||
|
||||
assert_eq!(scopes.len(), 1);
|
||||
let (ann_scope_id, ann_scope) = scopes[0];
|
||||
assert_eq!(ann_scope.kind(), ScopeKind::Annotation);
|
||||
assert_eq!(ann_scope.name(), "C");
|
||||
let ann_table = index.symbol_table(ann_scope_id);
|
||||
assert_eq!(names(&ann_table), vec!["T"]);
|
||||
assert!(
|
||||
ann_table
|
||||
.symbol_by_name("T")
|
||||
.is_some_and(|s| s.is_defined() && !s.is_used()),
|
||||
"type parameters are defined by the scope that introduces them"
|
||||
);
|
||||
|
||||
let scopes: Vec<_> = index.child_scopes(ann_scope_id).collect();
|
||||
assert_eq!(scopes.len(), 1);
|
||||
let (func_scope_id, func_scope) = scopes[0];
|
||||
|
||||
assert_eq!(func_scope.kind(), ScopeKind::Class);
|
||||
assert_eq!(func_scope.name(), "C");
|
||||
assert_eq!(names(&index.symbol_table(func_scope_id)), vec!["x"]);
|
||||
}
|
||||
|
||||
// TODO: After porting the control flow graph.
|
||||
// #[test]
|
||||
// fn reachability_trivial() {
|
||||
// let parsed = parse("x = 1; x");
|
||||
// let ast = parsed.syntax();
|
||||
// let index = SemanticIndex::from_ast(ast);
|
||||
// let table = &index.symbol_table;
|
||||
// let x_sym = table
|
||||
// .root_symbol_id_by_name("x")
|
||||
// .expect("x symbol should exist");
|
||||
// let ast::Stmt::Expr(ast::StmtExpr { value: x_use, .. }) = &ast.body[1] else {
|
||||
// panic!("should be an expr")
|
||||
// };
|
||||
// let x_defs: Vec<_> = index
|
||||
// .reachable_definitions(x_sym, x_use)
|
||||
// .map(|constrained_definition| constrained_definition.definition)
|
||||
// .collect();
|
||||
// assert_eq!(x_defs.len(), 1);
|
||||
// let Definition::Assignment(node_key) = &x_defs[0] else {
|
||||
// panic!("def should be an assignment")
|
||||
// };
|
||||
// let Some(def_node) = node_key.resolve(ast.into()) else {
|
||||
// panic!("node key should resolve")
|
||||
// };
|
||||
// let ast::Expr::NumberLiteral(ast::ExprNumberLiteral {
|
||||
// value: ast::Number::Int(num),
|
||||
// ..
|
||||
// }) = &*def_node.value
|
||||
// else {
|
||||
// panic!("should be a number literal")
|
||||
// };
|
||||
// assert_eq!(*num, 1);
|
||||
// }
|
||||
|
||||
#[test]
|
||||
fn expression_scope() {
|
||||
let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4");
|
||||
|
||||
let index = semantic_index(&db, file);
|
||||
let parsed = parsed_module(&db, file);
|
||||
let ast = parsed.syntax();
|
||||
|
||||
let x_stmt = ast.body[0].as_assign_stmt().unwrap();
|
||||
let x = &x_stmt.targets[0];
|
||||
|
||||
assert_eq!(index.expression_scope(x).kind(), ScopeKind::Module);
|
||||
assert_eq!(index.expression_scope_id(x), FileScopeId::root());
|
||||
|
||||
let def = ast.body[1].as_function_def_stmt().unwrap();
|
||||
let y_stmt = def.body[0].as_assign_stmt().unwrap();
|
||||
let y = &y_stmt.targets[0];
|
||||
|
||||
assert_eq!(index.expression_scope(y).kind(), ScopeKind::Function);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scope_iterators() {
|
||||
let TestCase { db, file } = test_case(
|
||||
r#"
|
||||
class Test:
|
||||
def foo():
|
||||
def bar():
|
||||
...
|
||||
def baz():
|
||||
pass
|
||||
|
||||
def x():
|
||||
pass"#,
|
||||
);
|
||||
|
||||
let index = semantic_index(&db, file);
|
||||
|
||||
let descendents: Vec<_> = index
|
||||
.descendent_scopes(FileScopeId::root())
|
||||
.map(|(_, scope)| scope.name().as_str())
|
||||
.collect();
|
||||
assert_eq!(descendents, vec!["Test", "foo", "bar", "baz", "x"]);
|
||||
|
||||
let children: Vec<_> = index
|
||||
.child_scopes(FileScopeId::root())
|
||||
.map(|(_, scope)| scope.name.as_str())
|
||||
.collect();
|
||||
assert_eq!(children, vec!["Test", "x"]);
|
||||
|
||||
let test_class = index.child_scopes(FileScopeId::root()).next().unwrap().0;
|
||||
let test_child_scopes: Vec<_> = index
|
||||
.child_scopes(test_class)
|
||||
.map(|(_, scope)| scope.name.as_str())
|
||||
.collect();
|
||||
assert_eq!(test_child_scopes, vec!["foo", "baz"]);
|
||||
|
||||
let bar_scope = index
|
||||
.descendent_scopes(FileScopeId::root())
|
||||
.nth(2)
|
||||
.unwrap()
|
||||
.0;
|
||||
let ancestors: Vec<_> = index
|
||||
.ancestor_scopes(bar_scope)
|
||||
.map(|(_, scope)| scope.name())
|
||||
.collect();
|
||||
|
||||
assert_eq!(ancestors, vec!["bar", "foo", "Test", "<module>"]);
|
||||
}
|
||||
}
|
393
crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs
Normal file
@ -0,0 +1,393 @@
use rustc_hash::FxHashMap;
|
||||
|
||||
use ruff_db::parsed::ParsedModule;
|
||||
use ruff_db::vfs::VfsFile;
|
||||
use ruff_index::{newtype_index, IndexVec};
|
||||
use ruff_python_ast as ast;
|
||||
use ruff_python_ast::AnyNodeRef;
|
||||
|
||||
use crate::ast_node_ref::AstNodeRef;
|
||||
use crate::node_key::NodeKey;
|
||||
use crate::semantic_index::semantic_index;
|
||||
use crate::semantic_index::symbol::{FileScopeId, ScopeId};
|
||||
use crate::Db;
|
||||
|
||||
/// AST ids for a single scope.
|
||||
///
|
||||
/// The motivation for building the AST ids per scope isn't to reduce invalidation, because
/// the struct changes whenever the parsed AST changes. Instead, it's mainly so that we can
/// build the AST ids struct while building the symbol table, and keep the property that
/// the IDs of outer scopes are unaffected by changes in inner scopes.
|
||||
///
|
||||
/// For example, adding new statements to `foo` shouldn't change the statement id of `x = foo()` in:
|
||||
///
|
||||
/// ```python
|
||||
/// def foo():
|
||||
/// return 5
|
||||
///
|
||||
/// x = foo()
|
||||
/// ```
|
||||
pub(crate) struct AstIds {
|
||||
/// Maps expression ids to their expressions.
|
||||
expressions: IndexVec<ScopeExpressionId, AstNodeRef<ast::Expr>>,
|
||||
|
||||
/// Maps expressions to their expression id. Uses `NodeKey` because it avoids cloning [`Parsed`].
|
||||
expressions_map: FxHashMap<NodeKey, ScopeExpressionId>,
|
||||
|
||||
statements: IndexVec<ScopeStatementId, AstNodeRef<ast::Stmt>>,
|
||||
|
||||
statements_map: FxHashMap<NodeKey, ScopeStatementId>,
|
||||
}
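// Editor's note (not part of this commit): with per-scope numbering, the module-level
// statement in the doc example keeps the same id value even when `foo`'s body grows,
// because ids inside `foo` come from `foo`'s own `AstIds` rather than a file-wide counter:
//
//     let module_scope = FileScopeId::root();
//     let assign_id = assign_stmt.scope_ast_id(db, file, module_scope); // stable id value
//     // (`assign_stmt` is an assumed `&ast::StmtAssign` for `x = foo()`)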
|
||||
|
||||
impl AstIds {
|
||||
fn statement_id<'a, N>(&self, node: N) -> ScopeStatementId
|
||||
where
|
||||
N: Into<AnyNodeRef<'a>>,
|
||||
{
|
||||
self.statements_map[&NodeKey::from_node(node.into())]
|
||||
}
|
||||
|
||||
fn expression_id<'a, N>(&self, node: N) -> ScopeExpressionId
|
||||
where
|
||||
N: Into<AnyNodeRef<'a>>,
|
||||
{
|
||||
self.expressions_map[&NodeKey::from_node(node.into())]
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::missing_fields_in_debug)]
|
||||
impl std::fmt::Debug for AstIds {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AstIds")
|
||||
.field("expressions", &self.expressions)
|
||||
.field("statements", &self.statements)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
fn ast_ids(db: &dyn Db, scope: ScopeId) -> &AstIds {
|
||||
semantic_index(db, scope.file(db)).ast_ids(scope.file_scope_id(db))
|
||||
}
|
||||
|
||||
/// Node that can be uniquely identified by an id in a [`FileScopeId`].
|
||||
pub trait ScopeAstIdNode {
|
||||
/// The type of the ID uniquely identifying the node.
|
||||
type Id: Copy;
|
||||
|
||||
/// Returns the ID that uniquely identifies the node in `scope`.
|
||||
///
|
||||
/// ## Panics
|
||||
/// Panics if the node doesn't belong to `file` or is outside `scope`.
|
||||
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> Self::Id;
|
||||
|
||||
/// Looks up the AST node by its ID.
|
||||
///
|
||||
/// ## Panics
|
||||
/// May panic if the `id` does not belong to the AST of `file`, or is outside `scope`.
|
||||
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
/// Extension trait for AST nodes that can be resolved by an `AstId`.
|
||||
pub trait AstIdNode {
|
||||
type ScopeId: Copy;
|
||||
|
||||
/// Resolves the AST id of the node.
|
||||
///
|
||||
/// ## Panics
|
||||
/// May panic if the node does not belong to `file`'s AST or is outside `scope`. It may also
/// return an incorrect node in that case.
|
||||
fn ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> AstId<Self::ScopeId>;
|
||||
|
||||
/// Resolves the AST node for `id`.
|
||||
///
|
||||
/// ## Panics
|
||||
/// May panic if the `id` does not belong to the AST of `file`, or it may return an incorrect node.
|
||||
|
||||
fn lookup(db: &dyn Db, file: VfsFile, id: AstId<Self::ScopeId>) -> &Self
|
||||
where
|
||||
Self: Sized;
|
||||
}
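// Illustrative round-trip sketch (editor's addition, not part of this commit); `function_def`
// is an assumed `&ast::StmtFunctionDef` belonging to `file` and `scope`:
//
//     let id = function_def.scope_ast_id(db, file, scope);                   // node -> id
//     let node = ast::StmtFunctionDef::lookup_in_scope(db, file, scope, id); // id -> node
//     debug_assert_eq!(node.name.id, function_def.name.id);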
|
||||
|
||||
impl<T> AstIdNode for T
|
||||
where
|
||||
T: ScopeAstIdNode,
|
||||
{
|
||||
type ScopeId = T::Id;
|
||||
|
||||
fn ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> AstId<Self::ScopeId> {
|
||||
let in_scope_id = self.scope_ast_id(db, file, scope);
|
||||
AstId { scope, in_scope_id }
|
||||
}
|
||||
|
||||
fn lookup(db: &dyn Db, file: VfsFile, id: AstId<Self::ScopeId>) -> &Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let scope = id.scope;
|
||||
Self::lookup_in_scope(db, file, scope, id.in_scope_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Uniquely identifies an AST node in a file.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct AstId<L: Copy> {
|
||||
/// The node's scope.
|
||||
scope: FileScopeId,
|
||||
|
||||
/// The ID of the node inside [`Self::scope`].
|
||||
in_scope_id: L,
|
||||
}
|
||||
|
||||
impl<L: Copy> AstId<L> {
|
||||
pub(super) fn new(scope: FileScopeId, in_scope_id: L) -> Self {
|
||||
Self { scope, in_scope_id }
|
||||
}
|
||||
|
||||
pub(super) fn in_scope_id(self) -> L {
|
||||
self.in_scope_id
|
||||
}
|
||||
}
|
||||
|
||||
/// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`].
|
||||
#[newtype_index]
|
||||
pub struct ScopeExpressionId;
|
||||
|
||||
impl ScopeAstIdNode for ast::Expr {
|
||||
type Id = ScopeExpressionId;
|
||||
|
||||
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
let ast_ids = ast_ids(db, scope);
|
||||
ast_ids.expressions_map[&NodeKey::from_node(self)]
|
||||
}
|
||||
|
||||
fn lookup_in_scope(db: &dyn Db, file: VfsFile, file_scope: FileScopeId, id: Self::Id) -> &Self {
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
let ast_ids = ast_ids(db, scope);
|
||||
ast_ids.expressions[id].node()
|
||||
}
|
||||
}
|
||||
|
||||
/// Uniquely identifies an [`ast::Stmt`] in a [`FileScopeId`].
|
||||
#[newtype_index]
|
||||
pub struct ScopeStatementId;
|
||||
|
||||
impl ScopeAstIdNode for ast::Stmt {
|
||||
type Id = ScopeStatementId;
|
||||
|
||||
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
let ast_ids = ast_ids(db, scope);
|
||||
ast_ids.statement_id(self)
|
||||
}
|
||||
|
||||
fn lookup_in_scope(db: &dyn Db, file: VfsFile, file_scope: FileScopeId, id: Self::Id) -> &Self {
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
let ast_ids = ast_ids(db, scope);
|
||||
|
||||
ast_ids.statements[id].node()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
|
||||
pub struct ScopeFunctionId(pub(super) ScopeStatementId);
|
||||
|
||||
impl ScopeAstIdNode for ast::StmtFunctionDef {
|
||||
type Id = ScopeFunctionId;
|
||||
|
||||
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
let ast_ids = ast_ids(db, scope);
|
||||
ScopeFunctionId(ast_ids.statement_id(self))
|
||||
}
|
||||
|
||||
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
|
||||
ast::Stmt::lookup_in_scope(db, file, scope, id.0)
|
||||
.as_function_def_stmt()
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
|
||||
pub struct ScopeClassId(pub(super) ScopeStatementId);
|
||||
|
||||
impl ScopeAstIdNode for ast::StmtClassDef {
|
||||
type Id = ScopeClassId;
|
||||
|
||||
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
let ast_ids = ast_ids(db, scope);
|
||||
ScopeClassId(ast_ids.statement_id(self))
|
||||
}
|
||||
|
||||
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
|
||||
let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0);
|
||||
statement.as_class_def_stmt().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
|
||||
pub struct ScopeAssignmentId(pub(super) ScopeStatementId);
|
||||
|
||||
impl ScopeAstIdNode for ast::StmtAssign {
|
||||
type Id = ScopeAssignmentId;
|
||||
|
||||
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
let ast_ids = ast_ids(db, scope);
|
||||
ScopeAssignmentId(ast_ids.statement_id(self))
|
||||
}
|
||||
|
||||
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
|
||||
let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0);
|
||||
statement.as_assign_stmt().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
|
||||
pub struct ScopeAnnotatedAssignmentId(ScopeStatementId);
|
||||
|
||||
impl ScopeAstIdNode for ast::StmtAnnAssign {
|
||||
type Id = ScopeAnnotatedAssignmentId;
|
||||
|
||||
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
let ast_ids = ast_ids(db, scope);
|
||||
ScopeAnnotatedAssignmentId(ast_ids.statement_id(self))
|
||||
}
|
||||
|
||||
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
|
||||
let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0);
|
||||
statement.as_ann_assign_stmt().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
|
||||
pub struct ScopeImportId(pub(super) ScopeStatementId);
|
||||
|
||||
impl ScopeAstIdNode for ast::StmtImport {
|
||||
type Id = ScopeImportId;
|
||||
|
||||
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
let ast_ids = ast_ids(db, scope);
|
||||
ScopeImportId(ast_ids.statement_id(self))
|
||||
}
|
||||
|
||||
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
|
||||
let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0);
|
||||
statement.as_import_stmt().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
|
||||
pub struct ScopeImportFromId(pub(super) ScopeStatementId);
|
||||
|
||||
impl ScopeAstIdNode for ast::StmtImportFrom {
|
||||
type Id = ScopeImportFromId;
|
||||
|
||||
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
let ast_ids = ast_ids(db, scope);
|
||||
ScopeImportFromId(ast_ids.statement_id(self))
|
||||
}
|
||||
|
||||
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
|
||||
let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0);
|
||||
statement.as_import_from_stmt().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
|
||||
pub struct ScopeNamedExprId(pub(super) ScopeExpressionId);
|
||||
|
||||
impl ScopeAstIdNode for ast::ExprNamed {
|
||||
type Id = ScopeNamedExprId;
|
||||
|
||||
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
let ast_ids = ast_ids(db, scope);
|
||||
ScopeNamedExprId(ast_ids.expression_id(self))
|
||||
}
|
||||
|
||||
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let expression = ast::Expr::lookup_in_scope(db, file, scope, id.0);
|
||||
expression.as_named_expr().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct AstIdsBuilder {
|
||||
expressions: IndexVec<ScopeExpressionId, AstNodeRef<ast::Expr>>,
|
||||
expressions_map: FxHashMap<NodeKey, ScopeExpressionId>,
|
||||
statements: IndexVec<ScopeStatementId, AstNodeRef<ast::Stmt>>,
|
||||
statements_map: FxHashMap<NodeKey, ScopeStatementId>,
|
||||
}
|
||||
|
||||
impl AstIdsBuilder {
|
||||
pub(super) fn new() -> Self {
|
||||
Self {
|
||||
expressions: IndexVec::default(),
|
||||
expressions_map: FxHashMap::default(),
|
||||
statements: IndexVec::default(),
|
||||
statements_map: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds `stmt` to the AST ids map and returns its id.
|
||||
///
|
||||
/// ## Safety
|
||||
/// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires
|
||||
/// that `stmt` is a child of `parsed`.
|
||||
#[allow(unsafe_code)]
|
||||
pub(super) unsafe fn record_statement(
|
||||
&mut self,
|
||||
stmt: &ast::Stmt,
|
||||
parsed: &ParsedModule,
|
||||
) -> ScopeStatementId {
|
||||
let statement_id = self.statements.push(AstNodeRef::new(parsed.clone(), stmt));
|
||||
|
||||
self.statements_map
|
||||
.insert(NodeKey::from_node(stmt), statement_id);
|
||||
|
||||
statement_id
|
||||
}
|
||||
|
||||
/// Adds `expr` to the AST ids map and returns its id.
|
||||
///
|
||||
/// ## Safety
|
||||
/// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires
|
||||
/// that `expr` is a child of `parsed`.
|
||||
#[allow(unsafe_code)]
|
||||
pub(super) unsafe fn record_expression(
|
||||
&mut self,
|
||||
expr: &ast::Expr,
|
||||
parsed: &ParsedModule,
|
||||
) -> ScopeExpressionId {
|
||||
let expression_id = self.expressions.push(AstNodeRef::new(parsed.clone(), expr));
|
||||
|
||||
self.expressions_map
|
||||
.insert(NodeKey::from_node(expr), expression_id);
|
||||
|
||||
expression_id
|
||||
}
|
||||
|
||||
pub(super) fn finish(mut self) -> AstIds {
|
||||
self.expressions.shrink_to_fit();
|
||||
self.expressions_map.shrink_to_fit();
|
||||
self.statements.shrink_to_fit();
|
||||
self.statements_map.shrink_to_fit();
|
||||
|
||||
AstIds {
|
||||
expressions: self.expressions,
|
||||
expressions_map: self.expressions_map,
|
||||
statements: self.statements,
|
||||
statements_map: self.statements_map,
|
||||
}
|
||||
}
|
||||
}
|
454
crates/red_knot_python_semantic/src/semantic_index/builder.rs
Normal file
@ -0,0 +1,454 @@
use std::sync::Arc;
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use ruff_db::parsed::ParsedModule;
|
||||
use ruff_index::IndexVec;
|
||||
use ruff_python_ast as ast;
|
||||
use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor};
|
||||
|
||||
use crate::name::Name;
|
||||
use crate::node_key::NodeKey;
|
||||
use crate::semantic_index::ast_ids::{
|
||||
AstId, AstIdsBuilder, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, ScopeImportFromId,
|
||||
ScopeImportId, ScopeNamedExprId,
|
||||
};
|
||||
use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition};
|
||||
use crate::semantic_index::symbol::{
|
||||
FileScopeId, FileSymbolId, Scope, ScopedSymbolId, SymbolFlags, SymbolTableBuilder,
|
||||
};
|
||||
use crate::semantic_index::{NodeWithScopeId, SemanticIndex};
|
||||
|
||||
pub(super) struct SemanticIndexBuilder<'a> {
|
||||
// Builder state
|
||||
module: &'a ParsedModule,
|
||||
scope_stack: Vec<FileScopeId>,
|
||||
/// the definition whose target(s) we are currently walking
|
||||
current_definition: Option<Definition>,
|
||||
|
||||
// Semantic Index fields
|
||||
scopes: IndexVec<FileScopeId, Scope>,
|
||||
symbol_tables: IndexVec<FileScopeId, SymbolTableBuilder>,
|
||||
ast_ids: IndexVec<FileScopeId, AstIdsBuilder>,
|
||||
expression_scopes: FxHashMap<NodeKey, FileScopeId>,
|
||||
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>,
|
||||
}
|
||||
|
||||
impl<'a> SemanticIndexBuilder<'a> {
|
||||
pub(super) fn new(parsed: &'a ParsedModule) -> Self {
|
||||
let mut builder = Self {
|
||||
module: parsed,
|
||||
scope_stack: Vec::new(),
|
||||
current_definition: None,
|
||||
|
||||
scopes: IndexVec::new(),
|
||||
symbol_tables: IndexVec::new(),
|
||||
ast_ids: IndexVec::new(),
|
||||
expression_scopes: FxHashMap::default(),
|
||||
scope_nodes: IndexVec::new(),
|
||||
};
|
||||
|
||||
builder.push_scope_with_parent(
|
||||
NodeWithScopeId::Module,
|
||||
&Name::new_static("<module>"),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
builder
|
||||
}
|
||||
|
||||
fn current_scope(&self) -> FileScopeId {
|
||||
*self
|
||||
.scope_stack
|
||||
.last()
|
||||
.expect("Always to have a root scope")
|
||||
}
|
||||
|
||||
fn push_scope(
|
||||
&mut self,
|
||||
node: NodeWithScopeId,
|
||||
name: &Name,
|
||||
defining_symbol: Option<FileSymbolId>,
|
||||
definition: Option<Definition>,
|
||||
) {
|
||||
let parent = self.current_scope();
|
||||
self.push_scope_with_parent(node, name, defining_symbol, definition, Some(parent));
|
||||
}
|
||||
|
||||
fn push_scope_with_parent(
|
||||
&mut self,
|
||||
node: NodeWithScopeId,
|
||||
name: &Name,
|
||||
defining_symbol: Option<FileSymbolId>,
|
||||
definition: Option<Definition>,
|
||||
parent: Option<FileScopeId>,
|
||||
) {
|
||||
let children_start = self.scopes.next_index() + 1;
|
||||
|
||||
let scope = Scope {
|
||||
name: name.clone(),
|
||||
parent,
|
||||
defining_symbol,
|
||||
definition,
|
||||
kind: node.scope_kind(),
|
||||
descendents: children_start..children_start,
|
||||
};
|
||||
|
||||
let scope_id = self.scopes.push(scope);
|
||||
self.symbol_tables.push(SymbolTableBuilder::new());
|
||||
let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new());
|
||||
let scope_node_id = self.scope_nodes.push(node);
|
||||
|
||||
debug_assert_eq!(ast_id_scope, scope_id);
|
||||
debug_assert_eq!(scope_id, scope_node_id);
|
||||
self.scope_stack.push(scope_id);
|
||||
}
|
||||
|
||||
fn pop_scope(&mut self) -> FileScopeId {
|
||||
let id = self.scope_stack.pop().expect("Root scope to be present");
|
||||
let children_end = self.scopes.next_index();
|
||||
let scope = &mut self.scopes[id];
|
||||
scope.descendents = scope.descendents.start..children_end;
|
||||
id
|
||||
}
|
||||
|
||||
fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder {
|
||||
let scope_id = self.current_scope();
|
||||
&mut self.symbol_tables[scope_id]
|
||||
}
|
||||
|
||||
fn current_ast_ids(&mut self) -> &mut AstIdsBuilder {
|
||||
let scope_id = self.current_scope();
|
||||
&mut self.ast_ids[scope_id]
|
||||
}
|
||||
|
||||
fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId {
|
||||
let symbol_table = self.current_symbol_table();
|
||||
|
||||
symbol_table.add_or_update_symbol(name, flags, None)
|
||||
}
|
||||
|
||||
fn add_or_update_symbol_with_definition(
|
||||
&mut self,
|
||||
name: Name,
|
||||
|
||||
definition: Definition,
|
||||
) -> ScopedSymbolId {
|
||||
let symbol_table = self.current_symbol_table();
|
||||
|
||||
symbol_table.add_or_update_symbol(name, SymbolFlags::IS_DEFINED, Some(definition))
|
||||
}
|
||||
|
||||
fn with_type_params(
|
||||
&mut self,
|
||||
name: &Name,
|
||||
with_params: &WithTypeParams,
|
||||
defining_symbol: FileSymbolId,
|
||||
nested: impl FnOnce(&mut Self) -> FileScopeId,
|
||||
) -> FileScopeId {
|
||||
let type_params = with_params.type_parameters();
|
||||
|
||||
if let Some(type_params) = type_params {
|
||||
let type_node = match with_params {
|
||||
WithTypeParams::ClassDef { id, .. } => NodeWithScopeId::ClassTypeParams(*id),
|
||||
WithTypeParams::FunctionDef { id, .. } => NodeWithScopeId::FunctionTypeParams(*id),
|
||||
};
|
||||
|
||||
self.push_scope(
|
||||
type_node,
|
||||
name,
|
||||
Some(defining_symbol),
|
||||
Some(with_params.definition()),
|
||||
);
|
||||
for type_param in &type_params.type_params {
|
||||
let name = match type_param {
|
||||
ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name,
|
||||
ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => name,
|
||||
ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => name,
|
||||
};
|
||||
self.add_or_update_symbol(Name::new(name), SymbolFlags::IS_DEFINED);
|
||||
}
|
||||
}
|
||||
|
||||
let nested_scope = nested(self);
|
||||
|
||||
if type_params.is_some() {
|
||||
self.pop_scope();
|
||||
}
|
||||
|
||||
nested_scope
|
||||
}
|
||||
|
||||
pub(super) fn build(mut self) -> SemanticIndex {
|
||||
let module = self.module;
|
||||
self.visit_body(module.suite());
|
||||
|
||||
// Pop the root scope
|
||||
self.pop_scope();
|
||||
assert!(self.scope_stack.is_empty());
|
||||
|
||||
assert!(self.current_definition.is_none());
|
||||
|
||||
let mut symbol_tables: IndexVec<_, _> = self
|
||||
.symbol_tables
|
||||
.into_iter()
|
||||
.map(|builder| Arc::new(builder.finish()))
|
||||
.collect();
|
||||
|
||||
let mut ast_ids: IndexVec<_, _> = self
|
||||
.ast_ids
|
||||
.into_iter()
|
||||
.map(super::ast_ids::AstIdsBuilder::finish)
|
||||
.collect();
|
||||
|
||||
self.scopes.shrink_to_fit();
|
||||
ast_ids.shrink_to_fit();
|
||||
symbol_tables.shrink_to_fit();
|
||||
self.expression_scopes.shrink_to_fit();
|
||||
self.scope_nodes.shrink_to_fit();
|
||||
|
||||
SemanticIndex {
|
||||
symbol_tables,
|
||||
scopes: self.scopes,
|
||||
scope_nodes: self.scope_nodes,
|
||||
ast_ids,
|
||||
expression_scopes: self.expression_scopes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Visitor<'_> for SemanticIndexBuilder<'_> {
|
||||
fn visit_stmt(&mut self, stmt: &ast::Stmt) {
|
||||
let module = self.module;
|
||||
#[allow(unsafe_code)]
|
||||
let statement_id = unsafe {
|
||||
// SAFETY: The builder only visits nodes that are part of `module`. This guarantees that
|
||||
// the current statement must be a child of `module`.
|
||||
self.current_ast_ids().record_statement(stmt, module)
|
||||
};
|
||||
match stmt {
|
||||
ast::Stmt::FunctionDef(function_def) => {
|
||||
for decorator in &function_def.decorator_list {
|
||||
self.visit_decorator(decorator);
|
||||
}
|
||||
let name = Name::new(&function_def.name.id);
|
||||
let function_id = ScopeFunctionId(statement_id);
|
||||
let definition = Definition::FunctionDef(function_id);
|
||||
let scope = self.current_scope();
|
||||
let symbol = FileSymbolId::new(
|
||||
scope,
|
||||
self.add_or_update_symbol_with_definition(name.clone(), definition),
|
||||
);
|
||||
|
||||
self.with_type_params(
|
||||
&name,
|
||||
&WithTypeParams::FunctionDef {
|
||||
node: function_def,
|
||||
id: AstId::new(scope, function_id),
|
||||
},
|
||||
symbol,
|
||||
|builder| {
|
||||
builder.visit_parameters(&function_def.parameters);
|
||||
for expr in &function_def.returns {
|
||||
builder.visit_annotation(expr);
|
||||
}
|
||||
|
||||
builder.push_scope(
|
||||
NodeWithScopeId::Function(AstId::new(scope, function_id)),
|
||||
&name,
|
||||
Some(symbol),
|
||||
Some(definition),
|
||||
);
|
||||
builder.visit_body(&function_def.body);
|
||||
builder.pop_scope()
|
||||
},
|
||||
);
|
||||
}
|
||||
ast::Stmt::ClassDef(class) => {
|
||||
for decorator in &class.decorator_list {
|
||||
self.visit_decorator(decorator);
|
||||
}
|
||||
|
||||
let name = Name::new(&class.name.id);
|
||||
let class_id = ScopeClassId(statement_id);
|
||||
let definition = Definition::from(class_id);
|
||||
let scope = self.current_scope();
|
||||
let id = FileSymbolId::new(
|
||||
self.current_scope(),
|
||||
self.add_or_update_symbol_with_definition(name.clone(), definition),
|
||||
);
|
||||
self.with_type_params(
|
||||
&name,
|
||||
&WithTypeParams::ClassDef {
|
||||
node: class,
|
||||
id: AstId::new(scope, class_id),
|
||||
},
|
||||
id,
|
||||
|builder| {
|
||||
if let Some(arguments) = &class.arguments {
|
||||
builder.visit_arguments(arguments);
|
||||
}
|
||||
|
||||
builder.push_scope(
|
||||
NodeWithScopeId::Class(AstId::new(scope, class_id)),
|
||||
&name,
|
||||
Some(id),
|
||||
Some(definition),
|
||||
);
|
||||
builder.visit_body(&class.body);
|
||||
|
||||
builder.pop_scope()
|
||||
},
|
||||
);
|
||||
}
|
||||
ast::Stmt::Import(ast::StmtImport { names, .. }) => {
|
||||
for (i, alias) in names.iter().enumerate() {
|
||||
let symbol_name = if let Some(asname) = &alias.asname {
|
||||
asname.id.as_str()
|
||||
} else {
|
||||
alias.name.id.split('.').next().unwrap()
|
||||
};
|
||||
|
||||
let def = Definition::Import(ImportDefinition {
|
||||
import_id: ScopeImportId(statement_id),
|
||||
alias: u32::try_from(i).unwrap(),
|
||||
});
|
||||
self.add_or_update_symbol_with_definition(Name::new(symbol_name), def);
|
||||
}
|
||||
}
|
||||
ast::Stmt::ImportFrom(ast::StmtImportFrom {
|
||||
module: _,
|
||||
names,
|
||||
level: _,
|
||||
..
|
||||
}) => {
|
||||
for (i, alias) in names.iter().enumerate() {
|
||||
let symbol_name = if let Some(asname) = &alias.asname {
|
||||
asname.id.as_str()
|
||||
} else {
|
||||
alias.name.id.as_str()
|
||||
};
|
||||
let def = Definition::ImportFrom(ImportFromDefinition {
|
||||
import_id: ScopeImportFromId(statement_id),
|
||||
name: u32::try_from(i).unwrap(),
|
||||
});
|
||||
self.add_or_update_symbol_with_definition(Name::new(symbol_name), def);
|
||||
}
|
||||
}
|
||||
ast::Stmt::Assign(node) => {
|
||||
debug_assert!(self.current_definition.is_none());
|
||||
self.visit_expr(&node.value);
|
||||
self.current_definition =
|
||||
Some(Definition::Assignment(ScopeAssignmentId(statement_id)));
|
||||
for target in &node.targets {
|
||||
self.visit_expr(target);
|
||||
}
|
||||
self.current_definition = None;
|
||||
}
|
||||
_ => {
|
||||
walk_stmt(self, stmt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_expr(&mut self, expr: &'_ ast::Expr) {
|
||||
let module = self.module;
|
||||
#[allow(unsafe_code)]
|
||||
let expression_id = unsafe {
|
||||
// SAFETY: The builder only visits nodes that are part of `module`. This guarantees that
|
||||
// the current expression must be a child of `module`.
|
||||
self.current_ast_ids().record_expression(expr, module)
|
||||
};
|
||||
|
||||
self.expression_scopes
|
||||
.insert(NodeKey::from_node(expr), self.current_scope());
|
||||
|
||||
match expr {
|
||||
ast::Expr::Name(ast::ExprName { id, ctx, .. }) => {
|
||||
let flags = match ctx {
|
||||
ast::ExprContext::Load => SymbolFlags::IS_USED,
|
||||
ast::ExprContext::Store => SymbolFlags::IS_DEFINED,
|
||||
ast::ExprContext::Del => SymbolFlags::IS_DEFINED,
|
||||
ast::ExprContext::Invalid => SymbolFlags::empty(),
|
||||
};
|
||||
match self.current_definition {
|
||||
Some(definition) if flags.contains(SymbolFlags::IS_DEFINED) => {
|
||||
self.add_or_update_symbol_with_definition(Name::new(id), definition);
|
||||
}
|
||||
_ => {
|
||||
self.add_or_update_symbol(Name::new(id), flags);
|
||||
}
|
||||
}
|
||||
|
||||
walk_expr(self, expr);
|
||||
}
|
||||
ast::Expr::Named(node) => {
|
||||
debug_assert!(self.current_definition.is_none());
|
||||
self.current_definition =
|
||||
Some(Definition::NamedExpr(ScopeNamedExprId(expression_id)));
|
||||
// TODO walrus in comprehensions is implicitly nonlocal
|
||||
self.visit_expr(&node.target);
|
||||
self.current_definition = None;
|
||||
self.visit_expr(&node.value);
|
||||
}
|
||||
ast::Expr::If(ast::ExprIf {
|
||||
body, test, orelse, ..
|
||||
}) => {
|
||||
// TODO detect statically known truthy or falsy test (via type inference, not naive
|
||||
// AST inspection, so we can't simplify here, need to record test expression in CFG
|
||||
// for later checking)
|
||||
|
||||
self.visit_expr(test);
|
||||
|
||||
// let if_branch = self.flow_graph_builder.add_branch(self.current_flow_node());
|
||||
|
||||
// self.set_current_flow_node(if_branch);
|
||||
// self.insert_constraint(test);
|
||||
self.visit_expr(body);
|
||||
|
||||
// let post_body = self.current_flow_node();
|
||||
|
||||
// self.set_current_flow_node(if_branch);
|
||||
self.visit_expr(orelse);
|
||||
|
||||
// let post_else = self
|
||||
// .flow_graph_builder
|
||||
// .add_phi(self.current_flow_node(), post_body);
|
||||
|
||||
// self.set_current_flow_node(post_else);
|
||||
}
|
||||
_ => {
|
||||
walk_expr(self, expr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum WithTypeParams<'a> {
|
||||
ClassDef {
|
||||
node: &'a ast::StmtClassDef,
|
||||
id: AstId<ScopeClassId>,
|
||||
},
|
||||
FunctionDef {
|
||||
node: &'a ast::StmtFunctionDef,
|
||||
id: AstId<ScopeFunctionId>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<'a> WithTypeParams<'a> {
|
||||
fn type_parameters(&self) -> Option<&'a ast::TypeParams> {
|
||||
match self {
|
||||
WithTypeParams::ClassDef { node, .. } => node.type_params.as_deref(),
|
||||
WithTypeParams::FunctionDef { node, .. } => node.type_params.as_deref(),
|
||||
}
|
||||
}
|
||||
|
||||
fn definition(&self) -> Definition {
|
||||
match self {
|
||||
WithTypeParams::ClassDef { id, .. } => Definition::ClassDef(id.in_scope_id()),
|
||||
WithTypeParams::FunctionDef { id, .. } => Definition::FunctionDef(id.in_scope_id()),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
use crate::semantic_index::ast_ids::{
|
||||
ScopeAnnotatedAssignmentId, ScopeAssignmentId, ScopeClassId, ScopeFunctionId,
|
||||
ScopeImportFromId, ScopeImportId, ScopeNamedExprId,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum Definition {
|
||||
Import(ImportDefinition),
|
||||
ImportFrom(ImportFromDefinition),
|
||||
ClassDef(ScopeClassId),
|
||||
FunctionDef(ScopeFunctionId),
|
||||
Assignment(ScopeAssignmentId),
|
||||
AnnotatedAssignment(ScopeAnnotatedAssignmentId),
|
||||
NamedExpr(ScopeNamedExprId),
|
||||
/// Represents the implicit initial definition of every name as "unbound".
|
||||
Unbound,
|
||||
// TODO with statements, except handlers, function args...
|
||||
}
|
||||
|
||||
impl From<ImportDefinition> for Definition {
|
||||
fn from(value: ImportDefinition) -> Self {
|
||||
Self::Import(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ImportFromDefinition> for Definition {
|
||||
fn from(value: ImportFromDefinition) -> Self {
|
||||
Self::ImportFrom(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ScopeClassId> for Definition {
|
||||
fn from(value: ScopeClassId) -> Self {
|
||||
Self::ClassDef(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ScopeFunctionId> for Definition {
|
||||
fn from(value: ScopeFunctionId) -> Self {
|
||||
Self::FunctionDef(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ScopeAssignmentId> for Definition {
|
||||
fn from(value: ScopeAssignmentId) -> Self {
|
||||
Self::Assignment(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ScopeAnnotatedAssignmentId> for Definition {
|
||||
fn from(value: ScopeAnnotatedAssignmentId) -> Self {
|
||||
Self::AnnotatedAssignment(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ScopeNamedExprId> for Definition {
|
||||
fn from(value: ScopeNamedExprId) -> Self {
|
||||
Self::NamedExpr(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct ImportDefinition {
|
||||
pub(crate) import_id: ScopeImportId,
|
||||
|
||||
/// Index into [`ruff_python_ast::StmtImport::names`].
|
||||
pub(crate) alias: u32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct ImportFromDefinition {
|
||||
pub(crate) import_id: ScopeImportFromId,
|
||||
|
||||
/// Index into [`ruff_python_ast::StmtImportFrom::names`].
|
||||
pub(crate) name: u32,
|
||||
}
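For orientation, a hedged sketch of how these index fields are read; the module names below are illustrative assumptions, not part of this commit. A single `import` statement with several aliases yields one definition per alias, all sharing the statement's id:

// For `import os, sys`: two definitions share one `import_id`, and `alias`
// selects the entry in `StmtImport::names` (0 for `os`, 1 for `sys`).
fn example_import_definitions(import_id: ScopeImportId) -> [Definition; 2] {
    [
        Definition::Import(ImportDefinition { import_id, alias: 0 }),
        Definition::Import(ImportDefinition { import_id, alias: 1 }),
    ]
}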
|
379
crates/red_knot_python_semantic/src/semantic_index/symbol.rs
Normal file
|
@ -0,0 +1,379 @@
|
|||
// Allow unused underscore violations generated by the salsa macro
|
||||
// TODO(micha): Contribute fix upstream
|
||||
#![allow(clippy::used_underscore_binding)]
|
||||
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::ops::Range;
|
||||
|
||||
use bitflags::bitflags;
|
||||
use hashbrown::hash_map::RawEntryMut;
|
||||
use rustc_hash::FxHasher;
|
||||
use smallvec::SmallVec;
|
||||
|
||||
use ruff_db::vfs::VfsFile;
|
||||
use ruff_index::{newtype_index, IndexVec};
|
||||
|
||||
use crate::name::Name;
|
||||
use crate::semantic_index::definition::Definition;
|
||||
use crate::semantic_index::{root_scope, semantic_index, symbol_table, SymbolMap};
|
||||
use crate::Db;
|
||||
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
pub struct Symbol {
|
||||
name: Name,
|
||||
flags: SymbolFlags,
|
||||
/// The nodes that define this symbol, in source order.
|
||||
definitions: SmallVec<[Definition; 4]>,
|
||||
}
|
||||
|
||||
impl Symbol {
|
||||
fn new(name: Name, definition: Option<Definition>) -> Self {
|
||||
Self {
|
||||
name,
|
||||
flags: SymbolFlags::empty(),
|
||||
definitions: definition.into_iter().collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn push_definition(&mut self, definition: Definition) {
|
||||
self.definitions.push(definition);
|
||||
}
|
||||
|
||||
fn insert_flags(&mut self, flags: SymbolFlags) {
|
||||
self.flags.insert(flags);
|
||||
}
|
||||
|
||||
/// The symbol's name.
|
||||
pub fn name(&self) -> &Name {
|
||||
&self.name
|
||||
}
|
||||
|
||||
/// Is the symbol used in its containing scope?
|
||||
pub fn is_used(&self) -> bool {
|
||||
self.flags.contains(SymbolFlags::IS_USED)
|
||||
}
|
||||
|
||||
/// Is the symbol defined in its containing scope?
|
||||
pub fn is_defined(&self) -> bool {
|
||||
self.flags.contains(SymbolFlags::IS_DEFINED)
|
||||
}
|
||||
|
||||
pub fn definitions(&self) -> &[Definition] {
|
||||
&self.definitions
|
||||
}
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub(super) struct SymbolFlags: u8 {
|
||||
const IS_USED = 1 << 0;
|
||||
const IS_DEFINED = 1 << 1;
|
||||
/// TODO: This flag is not yet set by anything
|
||||
const MARKED_GLOBAL = 1 << 2;
|
||||
/// TODO: This flag is not yet set by anything
|
||||
const MARKED_NONLOCAL = 1 << 3;
|
||||
}
|
||||
}
|
||||
|
||||
/// ID that uniquely identifies a public symbol defined in a module's root scope.
|
||||
#[salsa::tracked]
|
||||
pub struct PublicSymbolId {
|
||||
#[id]
|
||||
pub(crate) file: VfsFile,
|
||||
#[id]
|
||||
pub(crate) scoped_symbol_id: ScopedSymbolId,
|
||||
}
|
||||
|
||||
/// ID that uniquely identifies a symbol in a file.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct FileSymbolId {
|
||||
scope: FileScopeId,
|
||||
scoped_symbol_id: ScopedSymbolId,
|
||||
}
|
||||
|
||||
impl FileSymbolId {
|
||||
pub(super) fn new(scope: FileScopeId, symbol: ScopedSymbolId) -> Self {
|
||||
Self {
|
||||
scope,
|
||||
scoped_symbol_id: symbol,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scope(self) -> FileScopeId {
|
||||
self.scope
|
||||
}
|
||||
|
||||
pub(crate) fn scoped_symbol_id(self) -> ScopedSymbolId {
|
||||
self.scoped_symbol_id
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FileSymbolId> for ScopedSymbolId {
|
||||
fn from(val: FileSymbolId) -> Self {
|
||||
val.scoped_symbol_id()
|
||||
}
|
||||
}
|
||||
|
||||
/// Symbol ID that uniquely identifies a symbol inside a [`Scope`].
|
||||
#[newtype_index]
|
||||
pub struct ScopedSymbolId;
|
||||
|
||||
impl ScopedSymbolId {
|
||||
/// Converts the symbol to a public symbol.
|
||||
///
|
||||
/// # Panics
|
||||
/// May panic if the symbol does not belong to `file` or is not a symbol of `file`'s root scope.
|
||||
pub(crate) fn to_public_symbol(self, db: &dyn Db, file: VfsFile) -> PublicSymbolId {
|
||||
let symbols = public_symbols_map(db, file);
|
||||
symbols.public(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a mapping from [`FileScopeId`] to globally unique [`ScopeId`].
|
||||
#[salsa::tracked(return_ref)]
|
||||
pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap {
|
||||
let index = semantic_index(db, file);
|
||||
|
||||
let scopes: IndexVec<_, _> = index
|
||||
.scopes
|
||||
.indices()
|
||||
.map(|id| ScopeId::new(db, file, id))
|
||||
.collect();
|
||||
|
||||
ScopesMap { scopes }
|
||||
}
|
||||
|
||||
/// Maps from the file specific [`FileScopeId`] to the global [`ScopeId`] that can be used as a Salsa query parameter.
|
||||
///
|
||||
/// The [`SemanticIndex`] uses [`FileScopeId`] on a per-file level to identify scopes
|
||||
/// because they allow for more efficient storage of associated data
|
||||
/// (use of an [`IndexVec`] keyed by [`FileScopeId`] over an [`FxHashMap`] keyed by [`ScopeId`]).
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
pub(crate) struct ScopesMap {
|
||||
scopes: IndexVec<FileScopeId, ScopeId>,
|
||||
}
|
||||
|
||||
impl ScopesMap {
|
||||
/// Gets the program-wide unique scope id for the given file specific `scope_id`.
|
||||
fn get(&self, scope: FileScopeId) -> ScopeId {
|
||||
self.scopes[scope]
|
||||
}
|
||||
}
|
||||
|
||||
#[salsa::tracked(return_ref)]
|
||||
pub(crate) fn public_symbols_map(db: &dyn Db, file: VfsFile) -> PublicSymbolsMap {
|
||||
let module_scope = root_scope(db, file);
|
||||
let symbols = symbol_table(db, module_scope);
|
||||
|
||||
let public_symbols: IndexVec<_, _> = symbols
|
||||
.symbol_ids()
|
||||
.map(|id| PublicSymbolId::new(db, file, id))
|
||||
.collect();
|
||||
|
||||
PublicSymbolsMap {
|
||||
symbols: public_symbols,
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps each [`ScopedSymbolId`] of a file's root scope to the corresponding [`PublicSymbolId`] (a Salsa ingredient).
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
pub(crate) struct PublicSymbolsMap {
|
||||
symbols: IndexVec<ScopedSymbolId, PublicSymbolId>,
|
||||
}
|
||||
|
||||
impl PublicSymbolsMap {
|
||||
/// Resolve the [`PublicSymbolId`] for the module-level `symbol_id`.
|
||||
fn public(&self, symbol_id: ScopedSymbolId) -> PublicSymbolId {
|
||||
self.symbols[symbol_id]
|
||||
}
|
||||
}
|
||||
|
||||
/// A cross-module identifier of a scope that can be used as a salsa query parameter.
|
||||
#[salsa::tracked]
|
||||
pub struct ScopeId {
|
||||
#[allow(clippy::used_underscore_binding)]
|
||||
#[id]
|
||||
pub file: VfsFile,
|
||||
#[id]
|
||||
pub file_scope_id: FileScopeId,
|
||||
}
|
||||
|
||||
/// ID that uniquely identifies a scope inside of a module.
|
||||
#[newtype_index]
|
||||
pub struct FileScopeId;
|
||||
|
||||
impl FileScopeId {
|
||||
/// Returns the scope id of the Root scope.
|
||||
pub fn root() -> Self {
|
||||
FileScopeId::from_u32(0)
|
||||
}
|
||||
|
||||
pub fn to_scope_id(self, db: &dyn Db, file: VfsFile) -> ScopeId {
|
||||
scopes_map(db, file).get(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
pub struct Scope {
|
||||
pub(super) name: Name,
|
||||
pub(super) parent: Option<FileScopeId>,
|
||||
pub(super) definition: Option<Definition>,
|
||||
pub(super) defining_symbol: Option<FileSymbolId>,
|
||||
pub(super) kind: ScopeKind,
|
||||
pub(super) descendents: Range<FileScopeId>,
|
||||
}
|
||||
|
||||
impl Scope {
|
||||
pub fn name(&self) -> &Name {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn definition(&self) -> Option<Definition> {
|
||||
self.definition
|
||||
}
|
||||
|
||||
pub fn defining_symbol(&self) -> Option<FileSymbolId> {
|
||||
self.defining_symbol
|
||||
}
|
||||
|
||||
pub fn parent(self) -> Option<FileScopeId> {
|
||||
self.parent
|
||||
}
|
||||
|
||||
pub fn kind(&self) -> ScopeKind {
|
||||
self.kind
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ScopeKind {
|
||||
Module,
|
||||
Annotation,
|
||||
Class,
|
||||
Function,
|
||||
}
|
||||
|
||||
/// Symbol table for a specific [`Scope`].
|
||||
#[derive(Debug)]
|
||||
pub struct SymbolTable {
|
||||
/// The symbols in this scope.
|
||||
symbols: IndexVec<ScopedSymbolId, Symbol>,
|
||||
|
||||
/// The symbols indexed by name.
|
||||
symbols_by_name: SymbolMap,
|
||||
}
|
||||
|
||||
impl SymbolTable {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
symbols: IndexVec::new(),
|
||||
symbols_by_name: SymbolMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn shrink_to_fit(&mut self) {
|
||||
self.symbols.shrink_to_fit();
|
||||
}
|
||||
|
||||
pub(crate) fn symbol(&self, symbol_id: impl Into<ScopedSymbolId>) -> &Symbol {
|
||||
&self.symbols[symbol_id.into()]
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopedSymbolId> {
|
||||
self.symbols.indices()
|
||||
}
|
||||
|
||||
pub fn symbols(&self) -> impl Iterator<Item = &Symbol> {
|
||||
self.symbols.iter()
|
||||
}
|
||||
|
||||
/// Returns the symbol named `name`.
|
||||
#[allow(unused)]
|
||||
pub(crate) fn symbol_by_name(&self, name: &str) -> Option<&Symbol> {
|
||||
let id = self.symbol_id_by_name(name)?;
|
||||
Some(self.symbol(id))
|
||||
}
|
||||
|
||||
/// Returns the [`ScopedSymbolId`] of the symbol named `name`.
|
||||
pub(crate) fn symbol_id_by_name(&self, name: &str) -> Option<ScopedSymbolId> {
|
||||
let (id, ()) = self
|
||||
.symbols_by_name
|
||||
.raw_entry()
|
||||
.from_hash(Self::hash_name(name), |id| {
|
||||
self.symbol(*id).name().as_str() == name
|
||||
})?;
|
||||
|
||||
Some(*id)
|
||||
}
|
||||
|
||||
fn hash_name(name: &str) -> u64 {
|
||||
let mut hasher = FxHasher::default();
|
||||
name.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for SymbolTable {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
// We don't need to compare the symbols_by_name because the name is already captured in `Symbol`.
|
||||
self.symbols == other.symbols
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for SymbolTable {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct SymbolTableBuilder {
|
||||
table: SymbolTable,
|
||||
}
|
||||
|
||||
impl SymbolTableBuilder {
|
||||
pub(super) fn new() -> Self {
|
||||
Self {
|
||||
table: SymbolTable::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn add_or_update_symbol(
|
||||
&mut self,
|
||||
name: Name,
|
||||
flags: SymbolFlags,
|
||||
definition: Option<Definition>,
|
||||
) -> ScopedSymbolId {
|
||||
let hash = SymbolTable::hash_name(&name);
|
||||
let entry = self
|
||||
.table
|
||||
.symbols_by_name
|
||||
.raw_entry_mut()
|
||||
.from_hash(hash, |id| self.table.symbols[*id].name() == &name);
|
||||
|
||||
match entry {
|
||||
RawEntryMut::Occupied(entry) => {
|
||||
let symbol = &mut self.table.symbols[*entry.key()];
|
||||
symbol.insert_flags(flags);
|
||||
|
||||
if let Some(definition) = definition {
|
||||
symbol.push_definition(definition);
|
||||
}
|
||||
|
||||
*entry.key()
|
||||
}
|
||||
RawEntryMut::Vacant(entry) => {
|
||||
let mut symbol = Symbol::new(name, definition);
|
||||
symbol.insert_flags(flags);
|
||||
|
||||
let id = self.table.symbols.push(symbol);
|
||||
entry.insert_with_hasher(hash, id, (), |id| {
|
||||
SymbolTable::hash_name(self.table.symbols[*id].name().as_str())
|
||||
});
|
||||
id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn finish(mut self) -> SymbolTable {
|
||||
self.table.shrink_to_fit();
|
||||
self.table
|
||||
}
|
||||
}
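A hedged usage sketch of the builder's de-duplication by name (the literal name "x", and the assumption that `Name::new` accepts a string slice as it does elsewhere in this diff, are illustrative): inserting the same name twice returns the same id, merging flags and appending definitions instead of creating a second symbol.

fn example_symbol_dedup() {
    let mut builder = SymbolTableBuilder::new();
    // First sighting: `x` is a store target.
    let first = builder.add_or_update_symbol(Name::new("x"), SymbolFlags::IS_DEFINED, None);
    // Second sighting: `x` is loaded; the existing entry is reused.
    let second = builder.add_or_update_symbol(Name::new("x"), SymbolFlags::IS_USED, None);
    assert_eq!(first, second);

    let table = builder.finish();
    assert!(table.symbol(first).is_defined());
    assert!(table.symbol(first).is_used());
}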
|
682
crates/red_knot_python_semantic/src/types.rs
Normal file
|
@ -0,0 +1,682 @@
|
|||
use salsa::DebugWithDb;
|
||||
|
||||
use ruff_db::parsed::parsed_module;
|
||||
use ruff_db::vfs::VfsFile;
|
||||
use ruff_index::newtype_index;
|
||||
use ruff_python_ast as ast;
|
||||
|
||||
use crate::name::Name;
|
||||
use crate::semantic_index::ast_ids::{AstIdNode, ScopeAstIdNode};
|
||||
use crate::semantic_index::symbol::{FileScopeId, PublicSymbolId, ScopeId};
|
||||
use crate::semantic_index::{
|
||||
public_symbol, root_scope, semantic_index, symbol_table, NodeWithScopeId,
|
||||
};
|
||||
use crate::types::infer::{TypeInference, TypeInferenceBuilder};
|
||||
use crate::Db;
|
||||
use crate::FxIndexSet;
|
||||
|
||||
mod display;
|
||||
mod infer;
|
||||
|
||||
/// Infers the type of `expr`.
|
||||
///
|
||||
/// Calling this function from a salsa query adds a dependency on [`semantic_index`]
|
||||
/// which changes with every AST change. That's why you should only call
|
||||
/// this function for the current file that's being analyzed and not for
|
||||
/// a dependency (or the query will rerun whenever a dependency changes).
|
||||
///
|
||||
/// Prefer [`public_symbol_ty`] when resolving the type of symbol from another file.
|
||||
#[tracing::instrument(level = "debug", skip(db))]
|
||||
pub(crate) fn expression_ty(db: &dyn Db, file: VfsFile, expression: &ast::Expr) -> Type {
|
||||
let index = semantic_index(db, file);
|
||||
let file_scope = index.expression_scope_id(expression);
|
||||
let expression_id = expression.scope_ast_id(db, file, file_scope);
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
|
||||
infer_types(db, scope).expression_ty(expression_id)
|
||||
}
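A hedged sketch of the intended call pattern described above (the helper name and parameters are illustrative; the `local_inference` test further down exercises the same path end to end):

// Call `expression_ty` only for the file currently under analysis; for a symbol
// coming from another module, go through `public_symbol_ty` /
// `public_symbol_ty_by_name` so invalidation stays symbol-granular.
fn example_current_file_only(db: &dyn Db, file: VfsFile, expr: &ast::Expr) -> Type {
    expression_ty(db, file, expr)
}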
|
||||
|
||||
/// Infers the type of a public symbol.
|
||||
///
|
||||
/// This is a Salsa query to get symbol-level invalidation instead of file-level dependency invalidation.
|
||||
/// Without this being a query, changing any public type of a module would invalidate the type inference
|
||||
/// for the module scope of its dependents and of all transitive dependents.
|
||||
///
|
||||
/// For example if we have
|
||||
/// ```python
|
||||
/// # a.py
|
||||
/// from b import x
|
||||
///
|
||||
/// # b.py
|
||||
///
|
||||
/// x = 20
|
||||
/// ```
|
||||
///
|
||||
/// Now suppose `x` is changed from `x = 20` to `x = 30`. The following happens:
|
||||
///
|
||||
/// * The module level types of `b.py` change because `x` now is a `Literal[30]`.
|
||||
/// * The module level types of `a.py` change because the imported symbol `x` now has a `Literal[30]` type
|
||||
/// * The module level types of any dependents of `a.py` change because the imported symbol `x` now has a `Literal[30]` type
|
||||
/// * And so on for all transitive dependencies.
|
||||
///
|
||||
/// This being a query ensures that the invalidation short-circuits if the type of this symbol didn't change.
|
||||
#[salsa::tracked]
|
||||
pub(crate) fn public_symbol_ty(db: &dyn Db, symbol: PublicSymbolId) -> Type {
|
||||
let _ = tracing::debug_span!("public_symbol_ty", symbol = ?symbol.debug(db)).enter();
|
||||
|
||||
let file = symbol.file(db);
|
||||
let scope = root_scope(db, file);
|
||||
|
||||
let inference = infer_types(db, scope);
|
||||
inference.symbol_ty(symbol.scoped_symbol_id(db))
|
||||
}
|
||||
|
||||
/// Shorthand for `public_symbol_ty` that takes a symbol name instead of a [`PublicSymbolId`].
|
||||
pub fn public_symbol_ty_by_name(db: &dyn Db, file: VfsFile, name: &str) -> Option<Type> {
|
||||
let symbol = public_symbol(db, file, name)?;
|
||||
Some(public_symbol_ty(db, symbol))
|
||||
}
|
||||
|
||||
/// Infers all types for `scope`.
|
||||
#[salsa::tracked(return_ref)]
|
||||
pub(crate) fn infer_types(db: &dyn Db, scope: ScopeId) -> TypeInference {
|
||||
let file = scope.file(db);
|
||||
// Using the index here is fine because the code below depends on the AST anyway.
|
||||
// The query's isolation comes from the inferred types it returns.
|
||||
let index = semantic_index(db, file);
|
||||
|
||||
let scope_id = scope.file_scope_id(db);
|
||||
let node = index.scope_node(scope_id);
|
||||
|
||||
let mut context = TypeInferenceBuilder::new(db, scope, index);
|
||||
|
||||
match node {
|
||||
NodeWithScopeId::Module => {
|
||||
let parsed = parsed_module(db.upcast(), file);
|
||||
context.infer_module(parsed.syntax());
|
||||
}
|
||||
NodeWithScopeId::Class(class_id) => {
|
||||
let class = ast::StmtClassDef::lookup(db, file, class_id);
|
||||
context.infer_class_body(class);
|
||||
}
|
||||
NodeWithScopeId::ClassTypeParams(class_id) => {
|
||||
let class = ast::StmtClassDef::lookup(db, file, class_id);
|
||||
context.infer_class_type_params(class);
|
||||
}
|
||||
NodeWithScopeId::Function(function_id) => {
|
||||
let function = ast::StmtFunctionDef::lookup(db, file, function_id);
|
||||
context.infer_function_body(function);
|
||||
}
|
||||
NodeWithScopeId::FunctionTypeParams(function_id) => {
|
||||
let function = ast::StmtFunctionDef::lookup(db, file, function_id);
|
||||
context.infer_function_type_params(function);
|
||||
}
|
||||
}
|
||||
|
||||
context.finish()
|
||||
}
|
||||
|
||||
/// unique ID for a type
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum Type {
|
||||
/// the dynamic type: a statically-unknown set of values
|
||||
Any,
|
||||
/// the empty set of values
|
||||
Never,
|
||||
/// unknown type (no annotation)
|
||||
/// equivalent to Any, or to object in strict mode
|
||||
Unknown,
|
||||
/// name is not bound to any value
|
||||
Unbound,
|
||||
/// the None object (TODO: remove this in favor of `Instance(types.NoneType)`)
|
||||
None,
|
||||
/// a specific function object
|
||||
Function(TypeId<ScopedFunctionTypeId>),
|
||||
/// a specific module object
|
||||
Module(TypeId<ScopedModuleTypeId>),
|
||||
/// a specific class object
|
||||
Class(TypeId<ScopedClassTypeId>),
|
||||
/// the set of Python objects with the given class in their __class__'s method resolution order
|
||||
Instance(TypeId<ScopedClassTypeId>),
|
||||
Union(TypeId<ScopedUnionTypeId>),
|
||||
Intersection(TypeId<ScopedIntersectionTypeId>),
|
||||
IntLiteral(i64),
|
||||
// TODO protocols, callable types, overloads, generics, type vars
|
||||
}
|
||||
|
||||
impl Type {
|
||||
pub const fn is_unbound(&self) -> bool {
|
||||
matches!(self, Type::Unbound)
|
||||
}
|
||||
|
||||
pub const fn is_unknown(&self) -> bool {
|
||||
matches!(self, Type::Unknown)
|
||||
}
|
||||
|
||||
pub fn member(&self, context: &TypingContext, name: &Name) -> Option<Type> {
|
||||
match self {
|
||||
Type::Any => Some(Type::Any),
|
||||
Type::Never => todo!("attribute lookup on Never type"),
|
||||
Type::Unknown => Some(Type::Unknown),
|
||||
Type::Unbound => todo!("attribute lookup on Unbound type"),
|
||||
Type::None => todo!("attribute lookup on None type"),
|
||||
Type::Function(_) => todo!("attribute lookup on Function type"),
|
||||
Type::Module(module) => module.member(context, name),
|
||||
Type::Class(class) => class.class_member(context, name),
|
||||
Type::Instance(_) => {
|
||||
// TODO MRO? get_own_instance_member, get_instance_member
|
||||
todo!("attribute lookup on Instance type")
|
||||
}
|
||||
Type::Union(union_id) => {
|
||||
let _union = union_id.lookup(context);
|
||||
// TODO perform the get_member on each type in the union
|
||||
// TODO return the union of those results
|
||||
// TODO if any of those results is `None` then include Unknown in the result union
|
||||
todo!("attribute lookup on Union type")
|
||||
}
|
||||
Type::Intersection(_) => {
|
||||
// TODO perform the get_member on each type in the intersection
|
||||
// TODO return the intersection of those results
|
||||
todo!("attribute lookup on Intersection type")
|
||||
}
|
||||
Type::IntLiteral(_) => {
|
||||
// TODO raise error
|
||||
Some(Type::Unknown)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// ID that uniquely identifies a type in a program.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct TypeId<L> {
|
||||
/// The scope in which this type is defined or was created.
|
||||
scope: ScopeId,
|
||||
/// The type's local ID in its scope.
|
||||
scoped: L,
|
||||
}
|
||||
|
||||
impl<Id> TypeId<Id>
|
||||
where
|
||||
Id: Copy,
|
||||
{
|
||||
pub fn scope(&self) -> ScopeId {
|
||||
self.scope
|
||||
}
|
||||
|
||||
pub fn scoped_id(&self) -> Id {
|
||||
self.scoped
|
||||
}
|
||||
|
||||
/// Resolves the type ID to the actual type.
|
||||
pub(crate) fn lookup<'a>(self, context: &'a TypingContext) -> &'a Id::Ty
|
||||
where
|
||||
Id: ScopedTypeId,
|
||||
{
|
||||
let types = context.types(self.scope);
|
||||
self.scoped.lookup_scoped(types)
|
||||
}
|
||||
}
|
||||
|
||||
/// ID that uniquely identifies a type in a scope.
|
||||
pub(crate) trait ScopedTypeId {
|
||||
/// The type that this ID points to.
|
||||
type Ty;
|
||||
|
||||
/// Looks up the type in `index`.
|
||||
///
|
||||
/// ## Panics
|
||||
/// May panic if this type is from a different scope than `index`, or it might just return an invalid type.
|
||||
fn lookup_scoped(self, index: &TypeInference) -> &Self::Ty;
|
||||
}
|
||||
|
||||
/// ID uniquely identifying a function type in a `scope`.
|
||||
#[newtype_index]
|
||||
pub struct ScopedFunctionTypeId;
|
||||
|
||||
impl ScopedTypeId for ScopedFunctionTypeId {
|
||||
type Ty = FunctionType;
|
||||
|
||||
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
|
||||
types.function_ty(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub struct FunctionType {
|
||||
/// name of the function at definition
|
||||
name: Name,
|
||||
/// types of all decorators on this function
|
||||
decorators: Vec<Type>,
|
||||
}
|
||||
|
||||
impl FunctionType {
|
||||
fn name(&self) -> &str {
|
||||
self.name.as_str()
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub(crate) fn decorators(&self) -> &[Type] {
|
||||
self.decorators.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
#[newtype_index]
|
||||
pub struct ScopedClassTypeId;
|
||||
|
||||
impl ScopedTypeId for ScopedClassTypeId {
|
||||
type Ty = ClassType;
|
||||
|
||||
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
|
||||
types.class_ty(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeId<ScopedClassTypeId> {
|
||||
/// Returns the class member of this class named `name`.
|
||||
///
|
||||
/// The member resolves to a member of the class itself or any of its bases.
|
||||
fn class_member(self, context: &TypingContext, name: &Name) -> Option<Type> {
|
||||
if let Some(member) = self.own_class_member(context, name) {
|
||||
return Some(member);
|
||||
}
|
||||
|
||||
let class = self.lookup(context);
|
||||
for base in &class.bases {
|
||||
if let Some(member) = base.member(context, name) {
|
||||
return Some(member);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns the inferred type of the class member named `name`.
|
||||
fn own_class_member(self, context: &TypingContext, name: &Name) -> Option<Type> {
|
||||
let class = self.lookup(context);
|
||||
|
||||
let symbols = symbol_table(context.db, class.body_scope);
|
||||
let symbol = symbols.symbol_id_by_name(name)?;
|
||||
let types = context.types(class.body_scope);
|
||||
|
||||
Some(types.symbol_ty(symbol))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub struct ClassType {
|
||||
/// Name of the class at definition
|
||||
name: Name,
|
||||
|
||||
/// Types of all class bases
|
||||
bases: Vec<Type>,
|
||||
|
||||
body_scope: ScopeId,
|
||||
}
|
||||
|
||||
impl ClassType {
|
||||
fn name(&self) -> &str {
|
||||
self.name.as_str()
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub(super) fn bases(&self) -> &[Type] {
|
||||
self.bases.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
#[newtype_index]
|
||||
pub struct ScopedUnionTypeId;
|
||||
|
||||
impl ScopedTypeId for ScopedUnionTypeId {
|
||||
type Ty = UnionType;
|
||||
|
||||
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
|
||||
types.union_ty(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub struct UnionType {
|
||||
// the union type includes values in any of these types
|
||||
elements: FxIndexSet<Type>,
|
||||
}
|
||||
|
||||
struct UnionTypeBuilder<'a> {
|
||||
elements: FxIndexSet<Type>,
|
||||
context: &'a TypingContext<'a>,
|
||||
}
|
||||
|
||||
impl<'a> UnionTypeBuilder<'a> {
|
||||
fn new(context: &'a TypingContext<'a>) -> Self {
|
||||
Self {
|
||||
context,
|
||||
elements: FxIndexSet::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a type to this union.
|
||||
fn add(mut self, ty: Type) -> Self {
|
||||
match ty {
|
||||
Type::Union(union_id) => {
|
||||
let union = union_id.lookup(self.context);
|
||||
self.elements.extend(&union.elements);
|
||||
}
|
||||
_ => {
|
||||
self.elements.insert(ty);
|
||||
}
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
fn build(self) -> UnionType {
|
||||
UnionType {
|
||||
elements: self.elements,
|
||||
}
|
||||
}
|
||||
}
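A hedged usage sketch (the context variable `ctx` and the function name are assumptions): the builder flattens nested unions, so adding an existing `Type::Union` merges its elements rather than nesting a union inside a union.

fn example_union<'a>(ctx: &'a TypingContext<'a>) -> UnionType {
    UnionTypeBuilder::new(ctx)
        .add(Type::IntLiteral(1))
        .add(Type::IntLiteral(2))
        // Adding a `Type::Union(..)` here would extend `elements` with that
        // union's members instead of inserting the union itself.
        .build()
}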
|
||||
|
||||
#[newtype_index]
|
||||
pub struct ScopedIntersectionTypeId;
|
||||
|
||||
impl ScopedTypeId for ScopedIntersectionTypeId {
|
||||
type Ty = IntersectionType;
|
||||
|
||||
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
|
||||
types.intersection_ty(self)
|
||||
}
|
||||
}
|
||||
|
||||
// Negation types aren't expressible in annotations, and are most likely to arise from type
|
||||
// narrowing along with intersections (e.g. `if not isinstance(...)`), so we represent them
|
||||
// directly in intersections rather than as a separate type. This sacrifices some efficiency in the
|
||||
// case where a Not appears outside an intersection (unclear when that could even happen, but we'd
|
||||
// have to represent it as a single-element intersection if it did) in exchange for better
|
||||
// efficiency in the within-intersection case.
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct IntersectionType {
|
||||
// the intersection type includes only values in all of these types
|
||||
positive: FxIndexSet<Type>,
|
||||
// the intersection type does not include any value in any of these types
|
||||
negative: FxIndexSet<Type>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct ScopedModuleTypeId;
|
||||
|
||||
impl ScopedTypeId for ScopedModuleTypeId {
|
||||
type Ty = ModuleType;
|
||||
|
||||
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
|
||||
types.module_ty()
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeId<ScopedModuleTypeId> {
|
||||
fn member(self, context: &TypingContext, name: &Name) -> Option<Type> {
|
||||
context.public_symbol_ty(self.scope.file(context.db), name)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub struct ModuleType {
|
||||
file: VfsFile,
|
||||
}
|
||||
|
||||
/// Context in which to resolve types.
|
||||
///
|
||||
/// This abstraction is necessary to support a uniform API that can be used
|
||||
/// while in the process of building the type inference structure for a scope
|
||||
/// but also when all types should be resolved by querying the db.
|
||||
pub struct TypingContext<'a> {
|
||||
db: &'a dyn Db,
|
||||
|
||||
/// The local type inference results for the scope that is currently being built.
|
||||
///
|
||||
/// Bypass the `db` when resolving the types for this scope.
|
||||
local: Option<(ScopeId, &'a TypeInference)>,
|
||||
}
|
||||
|
||||
impl<'a> TypingContext<'a> {
|
||||
/// Creates a context that resolves all types by querying the db.
|
||||
#[allow(unused)]
|
||||
pub(super) fn global(db: &'a dyn Db) -> Self {
|
||||
Self { db, local: None }
|
||||
}
|
||||
|
||||
/// Creates a context that bypasses the `db` when resolving types from `scope_id` and instead uses `types`.
|
||||
fn scoped(db: &'a dyn Db, scope_id: ScopeId, types: &'a TypeInference) -> Self {
|
||||
Self {
|
||||
db,
|
||||
local: Some((scope_id, types)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`TypeInference`] results (not guaranteed to be complete) for `scope_id`.
|
||||
fn types(&self, scope_id: ScopeId) -> &'a TypeInference {
|
||||
if let Some((scope, local_types)) = self.local {
|
||||
if scope == scope_id {
|
||||
return local_types;
|
||||
}
|
||||
}
|
||||
|
||||
infer_types(self.db, scope_id)
|
||||
}
|
||||
|
||||
fn module_ty(&self, file: VfsFile) -> Type {
|
||||
let scope = root_scope(self.db, file);
|
||||
|
||||
Type::Module(TypeId {
|
||||
scope,
|
||||
scoped: ScopedModuleTypeId,
|
||||
})
|
||||
}
|
||||
|
||||
/// Resolves the public type of a symbol named `name` defined in `file`.
|
||||
///
|
||||
/// This function calls [`public_symbol_ty`] if the local scope isn't the module scope of `file`.
|
||||
/// It otherwise tries to resolve the symbol type locally.
|
||||
fn public_symbol_ty(&self, file: VfsFile, name: &Name) -> Option<Type> {
|
||||
let symbol = public_symbol(self.db, file, name)?;
|
||||
|
||||
if let Some((scope, local_types)) = self.local {
|
||||
if scope.file_scope_id(self.db) == FileScopeId::root() && scope.file(self.db) == file {
|
||||
return Some(local_types.symbol_ty(symbol.scoped_symbol_id(self.db)));
|
||||
}
|
||||
}
|
||||
|
||||
Some(public_symbol_ty(self.db, symbol))
|
||||
}
|
||||
}
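A hedged sketch of the two construction modes (the function name is illustrative): a query-backed context resolves everything through Salsa queries, while a scoped context short-circuits lookups for the single scope whose inference is still being built and therefore not yet stored in the database.

fn example_contexts<'a>(db: &'a dyn Db, scope: ScopeId, in_progress: &'a TypeInference) {
    // Every lookup goes through the `infer_types` / `public_symbol_ty` queries.
    let _query_backed = TypingContext::global(db);
    // Lookups for `scope` are answered from `in_progress` instead of the db.
    let _building = TypingContext::scoped(db, scope, in_progress);
}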
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use ruff_db::file_system::FileSystemPathBuf;
|
||||
use ruff_db::parsed::parsed_module;
|
||||
use ruff_db::vfs::system_path_to_file;
|
||||
|
||||
use crate::db::tests::{
|
||||
assert_will_not_run_function_query, assert_will_run_function_query, TestDb,
|
||||
};
|
||||
use crate::module::resolver::{set_module_resolution_settings, ModuleResolutionSettings};
|
||||
use crate::semantic_index::root_scope;
|
||||
use crate::types::{expression_ty, infer_types, public_symbol_ty_by_name, TypingContext};
|
||||
|
||||
fn setup_db() -> TestDb {
|
||||
let mut db = TestDb::new();
|
||||
set_module_resolution_settings(
|
||||
&mut db,
|
||||
ModuleResolutionSettings {
|
||||
extra_paths: vec![],
|
||||
workspace_root: FileSystemPathBuf::from("/src"),
|
||||
site_packages: None,
|
||||
custom_typeshed: None,
|
||||
},
|
||||
);
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_inference() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system().write_file("/src/a.py", "x = 10")?;
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
|
||||
let parsed = parsed_module(&db, a);
|
||||
|
||||
let statement = parsed.suite().first().unwrap().as_assign_stmt().unwrap();
|
||||
|
||||
let literal_ty = expression_ty(&db, a, &statement.value);
|
||||
|
||||
assert_eq!(
|
||||
format!("{}", literal_ty.display(&TypingContext::global(&db))),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dependency_public_symbol_type_change() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
("/src/a.py", "from foo import x"),
|
||||
("/src/foo.py", "x = 10\ndef foo(): ..."),
|
||||
])?;
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
// Change `x` to a different value
|
||||
db.memory_file_system()
|
||||
.write_file("/src/foo.py", "x = 20\ndef foo(): ...")?;
|
||||
|
||||
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
|
||||
foo.touch(&mut db);
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
|
||||
db.clear_salsa_events();
|
||||
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty_2.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[20]"
|
||||
);
|
||||
|
||||
let events = db.take_salsa_events();
|
||||
|
||||
let a_root_scope = root_scope(&db, a);
|
||||
assert_will_run_function_query::<infer_types, _, _>(
|
||||
&db,
|
||||
|ty| &ty.function,
|
||||
a_root_scope,
|
||||
&events,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dependency_non_public_symbol_change() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
("/src/a.py", "from foo import x"),
|
||||
("/src/foo.py", "x = 10\ndef foo(): y = 1"),
|
||||
])?;
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
db.memory_file_system()
|
||||
.write_file("/src/foo.py", "x = 10\ndef foo(): pass")?;
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
|
||||
|
||||
foo.touch(&mut db);
|
||||
|
||||
db.clear_salsa_events();
|
||||
|
||||
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty_2.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
let events = db.take_salsa_events();
|
||||
|
||||
let a_root_scope = root_scope(&db, a);
|
||||
|
||||
assert_will_not_run_function_query::<infer_types, _, _>(
|
||||
&db,
|
||||
|ty| &ty.function,
|
||||
a_root_scope,
|
||||
&events,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dependency_unrelated_public_symbol() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
("/src/a.py", "from foo import x"),
|
||||
("/src/foo.py", "x = 10\ny = 20"),
|
||||
])?;
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
db.memory_file_system()
|
||||
.write_file("/src/foo.py", "x = 10\ny = 30")?;
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
|
||||
|
||||
foo.touch(&mut db);
|
||||
|
||||
db.clear_salsa_events();
|
||||
|
||||
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty_2.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
let events = db.take_salsa_events();
|
||||
|
||||
let a_root_scope = root_scope(&db, a);
|
||||
assert_will_not_run_function_query::<infer_types, _, _>(
|
||||
&db,
|
||||
|ty| &ty.function,
|
||||
a_root_scope,
|
||||
&events,
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
175
crates/red_knot_python_semantic/src/types/display.rs
Normal file
|
@ -0,0 +1,175 @@
|
|||
//! Display implementations for types.
|
||||
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
use crate::types::{IntersectionType, Type, TypingContext, UnionType};
|
||||
|
||||
impl Type {
|
||||
pub fn display<'a>(&'a self, context: &'a TypingContext) -> DisplayType<'a> {
|
||||
DisplayType { ty: self, context }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct DisplayType<'a> {
|
||||
ty: &'a Type,
|
||||
context: &'a TypingContext<'a>,
|
||||
}
|
||||
|
||||
impl Display for DisplayType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self.ty {
|
||||
Type::Any => f.write_str("Any"),
|
||||
Type::Never => f.write_str("Never"),
|
||||
Type::Unknown => f.write_str("Unknown"),
|
||||
Type::Unbound => f.write_str("Unbound"),
|
||||
Type::None => f.write_str("None"),
|
||||
Type::Module(module_id) => {
|
||||
write!(
|
||||
f,
|
||||
"<module '{:?}'>",
|
||||
module_id
|
||||
.scope
|
||||
.file(self.context.db)
|
||||
.path(self.context.db.upcast())
|
||||
)
|
||||
}
|
||||
// TODO functions and classes should display using a fully qualified name
|
||||
Type::Class(class_id) => {
|
||||
let class = class_id.lookup(self.context);
|
||||
|
||||
f.write_str("Literal[")?;
|
||||
f.write_str(class.name())?;
|
||||
f.write_str("]")
|
||||
}
|
||||
Type::Instance(class_id) => {
|
||||
let class = class_id.lookup(self.context);
|
||||
f.write_str(class.name())
|
||||
}
|
||||
Type::Function(function_id) => {
|
||||
let function = function_id.lookup(self.context);
|
||||
f.write_str(function.name())
|
||||
}
|
||||
Type::Union(union_id) => {
|
||||
let union = union_id.lookup(self.context);
|
||||
|
||||
union.display(self.context).fmt(f)
|
||||
}
|
||||
Type::Intersection(intersection_id) => {
|
||||
let intersection = intersection_id.lookup(self.context);
|
||||
|
||||
intersection.display(self.context).fmt(f)
|
||||
}
|
||||
Type::IntLiteral(n) => write!(f, "Literal[{n}]"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for DisplayType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
std::fmt::Display::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl UnionType {
|
||||
fn display<'a>(&'a self, context: &'a TypingContext<'a>) -> DisplayUnionType<'a> {
|
||||
DisplayUnionType { context, ty: self }
|
||||
}
|
||||
}
|
||||
|
||||
struct DisplayUnionType<'a> {
|
||||
ty: &'a UnionType,
|
||||
context: &'a TypingContext<'a>,
|
||||
}
|
||||
|
||||
impl Display for DisplayUnionType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
let union = self.ty;
|
||||
|
||||
let (int_literals, other_types): (Vec<Type>, Vec<Type>) = union
|
||||
.elements
|
||||
.iter()
|
||||
.copied()
|
||||
.partition(|ty| matches!(ty, Type::IntLiteral(_)));
|
||||
|
||||
let mut first = true;
|
||||
if !int_literals.is_empty() {
|
||||
f.write_str("Literal[")?;
|
||||
let mut nums: Vec<_> = int_literals
|
||||
.into_iter()
|
||||
.filter_map(|ty| {
|
||||
if let Type::IntLiteral(n) = ty {
|
||||
Some(n)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
nums.sort_unstable();
|
||||
for num in nums {
|
||||
if !first {
|
||||
f.write_str(", ")?;
|
||||
}
|
||||
write!(f, "{num}")?;
|
||||
first = false;
|
||||
}
|
||||
f.write_str("]")?;
|
||||
}
|
||||
|
||||
for ty in other_types {
|
||||
if !first {
|
||||
f.write_str(" | ")?;
|
||||
};
|
||||
first = false;
|
||||
write!(f, "{}", ty.display(self.context))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for DisplayUnionType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
std::fmt::Display::fmt(self, f)
|
||||
}
|
||||
}
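A hedged illustration of the grouping above (the element types are assumptions): integer literals in a union are pulled out, sorted, and rendered as one `Literal[...]` segment before the remaining members.

// Sketch of the expected rendering (not an actual test in this commit):
// elements = { IntLiteral(3), IntLiteral(1), Instance(C) }
// display  => "Literal[1, 3] | C"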
|
||||
|
||||
impl IntersectionType {
|
||||
fn display<'a>(&'a self, context: &'a TypingContext<'a>) -> DisplayIntersectionType<'a> {
|
||||
DisplayIntersectionType { ty: self, context }
|
||||
}
|
||||
}
|
||||
|
||||
struct DisplayIntersectionType<'a> {
|
||||
ty: &'a IntersectionType,
|
||||
context: &'a TypingContext<'a>,
|
||||
}
|
||||
|
||||
impl Display for DisplayIntersectionType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
let mut first = true;
|
||||
for (neg, ty) in self
|
||||
.ty
|
||||
.positive
|
||||
.iter()
|
||||
.map(|ty| (false, ty))
|
||||
.chain(self.ty.negative.iter().map(|ty| (true, ty)))
|
||||
{
|
||||
if !first {
|
||||
f.write_str(" & ")?;
|
||||
};
|
||||
first = false;
|
||||
if neg {
|
||||
f.write_str("~")?;
|
||||
};
|
||||
write!(f, "{}", ty.display(self.context))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for DisplayIntersectionType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
std::fmt::Display::fmt(self, f)
|
||||
}
|
||||
}
|
941
crates/red_knot_python_semantic/src/types/infer.rs
Normal file
|
@ -0,0 +1,941 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use ruff_db::vfs::VfsFile;
|
||||
use ruff_index::IndexVec;
|
||||
use ruff_python_ast as ast;
|
||||
use ruff_python_ast::{ExprContext, TypeParams};
|
||||
|
||||
use crate::module::resolver::resolve_module;
|
||||
use crate::module::ModuleName;
|
||||
use crate::name::Name;
|
||||
use crate::semantic_index::ast_ids::{ScopeAstIdNode, ScopeExpressionId};
|
||||
use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition};
|
||||
use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable};
|
||||
use crate::semantic_index::{symbol_table, ChildrenIter, SemanticIndex};
|
||||
use crate::types::{
|
||||
ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId, ScopedFunctionTypeId,
|
||||
ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext, UnionType,
|
||||
UnionTypeBuilder,
|
||||
};
|
||||
use crate::Db;
|
||||
|
||||
/// The inferred types for a single scope.
|
||||
#[derive(Debug, Eq, PartialEq, Default, Clone)]
|
||||
pub(crate) struct TypeInference {
|
||||
/// The type of the module if the scope is a module scope.
|
||||
module_type: Option<ModuleType>,
|
||||
|
||||
/// The types of the defined classes in this scope.
|
||||
class_types: IndexVec<ScopedClassTypeId, ClassType>,
|
||||
|
||||
/// The types of the defined functions in this scope.
|
||||
function_types: IndexVec<ScopedFunctionTypeId, FunctionType>,
|
||||
|
||||
union_types: IndexVec<ScopedUnionTypeId, UnionType>,
|
||||
intersection_types: IndexVec<ScopedIntersectionTypeId, IntersectionType>,
|
||||
|
||||
/// The types of every expression in this scope.
|
||||
expression_tys: IndexVec<ScopeExpressionId, Type>,
|
||||
|
||||
/// The public types of every symbol in this scope.
|
||||
symbol_tys: IndexVec<ScopedSymbolId, Type>,
|
||||
}
|
||||
|
||||
impl TypeInference {
|
||||
#[allow(unused)]
|
||||
pub(super) fn expression_ty(&self, expression: ScopeExpressionId) -> Type {
|
||||
self.expression_tys[expression]
|
||||
}
|
||||
|
||||
pub(super) fn symbol_ty(&self, symbol: ScopedSymbolId) -> Type {
|
||||
self.symbol_tys[symbol]
|
||||
}
|
||||
|
||||
pub(super) fn module_ty(&self) -> &ModuleType {
|
||||
self.module_type.as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub(super) fn class_ty(&self, id: ScopedClassTypeId) -> &ClassType {
|
||||
&self.class_types[id]
|
||||
}
|
||||
|
||||
pub(super) fn function_ty(&self, id: ScopedFunctionTypeId) -> &FunctionType {
|
||||
&self.function_types[id]
|
||||
}
|
||||
|
||||
pub(super) fn union_ty(&self, id: ScopedUnionTypeId) -> &UnionType {
|
||||
&self.union_types[id]
|
||||
}
|
||||
|
||||
pub(super) fn intersection_ty(&self, id: ScopedIntersectionTypeId) -> &IntersectionType {
|
||||
&self.intersection_types[id]
|
||||
}
|
||||
|
||||
fn shrink_to_fit(&mut self) {
|
||||
self.class_types.shrink_to_fit();
|
||||
self.function_types.shrink_to_fit();
|
||||
self.union_types.shrink_to_fit();
|
||||
self.intersection_types.shrink_to_fit();
|
||||
|
||||
self.expression_tys.shrink_to_fit();
|
||||
self.symbol_tys.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder to infer all types in a [`ScopeId`].
|
||||
pub(super) struct TypeInferenceBuilder<'a> {
|
||||
db: &'a dyn Db,
|
||||
|
||||
// Cached lookups
|
||||
index: &'a SemanticIndex,
|
||||
scope: ScopeId,
|
||||
file_scope_id: FileScopeId,
|
||||
file_id: VfsFile,
|
||||
symbol_table: Arc<SymbolTable>,
|
||||
|
||||
/// The type inference results
|
||||
types: TypeInference,
|
||||
definition_tys: FxHashMap<Definition, Type>,
|
||||
children_scopes: ChildrenIter<'a>,
|
||||
}
|
||||
|
||||
impl<'a> TypeInferenceBuilder<'a> {
|
||||
/// Creates a new builder for inferring the types of `scope`.
|
||||
pub(super) fn new(db: &'a dyn Db, scope: ScopeId, index: &'a SemanticIndex) -> Self {
|
||||
let file_scope_id = scope.file_scope_id(db);
|
||||
let file = scope.file(db);
|
||||
let children_scopes = index.child_scopes(file_scope_id);
|
||||
let symbol_table = index.symbol_table(file_scope_id);
|
||||
|
||||
Self {
|
||||
index,
|
||||
file_scope_id,
|
||||
file_id: file,
|
||||
scope,
|
||||
symbol_table,
|
||||
|
||||
db,
|
||||
types: TypeInference::default(),
|
||||
definition_tys: FxHashMap::default(),
|
||||
children_scopes,
|
||||
}
|
||||
}
|
||||
|
||||
/// Infers the types of a `module`.
|
||||
pub(super) fn infer_module(&mut self, module: &ast::ModModule) {
|
||||
self.infer_body(&module.body);
|
||||
}
|
||||
|
||||
pub(super) fn infer_class_type_params(&mut self, class: &ast::StmtClassDef) {
|
||||
if let Some(type_params) = class.type_params.as_deref() {
|
||||
self.infer_type_parameters(type_params);
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn infer_class_body(&mut self, class: &ast::StmtClassDef) {
|
||||
self.infer_body(&class.body);
|
||||
}
|
||||
|
||||
pub(super) fn infer_function_type_params(&mut self, function: &ast::StmtFunctionDef) {
|
||||
if let Some(type_params) = function.type_params.as_deref() {
|
||||
self.infer_type_parameters(type_params);
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) {
|
||||
self.infer_body(&function.body);
|
||||
}
|
||||
|
||||
fn infer_body(&mut self, suite: &[ast::Stmt]) {
|
||||
for statement in suite {
|
||||
self.infer_statement(statement);
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_statement(&mut self, statement: &ast::Stmt) {
|
||||
match statement {
|
||||
ast::Stmt::FunctionDef(function) => self.infer_function_definition_statement(function),
|
||||
ast::Stmt::ClassDef(class) => self.infer_class_definition_statement(class),
|
||||
ast::Stmt::Expr(ast::StmtExpr { range: _, value }) => {
|
||||
self.infer_expression(value);
|
||||
}
|
||||
ast::Stmt::If(if_statement) => self.infer_if_statement(if_statement),
|
||||
ast::Stmt::Assign(assign) => self.infer_assignment_statement(assign),
|
||||
ast::Stmt::AnnAssign(assign) => self.infer_annotated_assignment_statement(assign),
|
||||
ast::Stmt::For(for_statement) => self.infer_for_statement(for_statement),
|
||||
ast::Stmt::Import(import) => self.infer_import_statement(import),
|
||||
ast::Stmt::ImportFrom(import) => self.infer_import_from_statement(import),
|
||||
ast::Stmt::Break(_) | ast::Stmt::Continue(_) | ast::Stmt::Pass(_) => {
|
||||
// No-op
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) {
|
||||
let ast::StmtFunctionDef {
|
||||
range: _,
|
||||
is_async: _,
|
||||
name,
|
||||
type_params: _,
|
||||
parameters: _,
|
||||
returns,
|
||||
body: _,
|
||||
decorator_list,
|
||||
} = function;
|
||||
|
||||
let function_id = function.scope_ast_id(self.db, self.file_id, self.file_scope_id);
|
||||
let decorator_tys = decorator_list
|
||||
.iter()
|
||||
.map(|decorator| self.infer_decorator(decorator))
|
||||
.collect();
|
||||
|
||||
// TODO: Infer parameters
|
||||
|
||||
if let Some(return_ty) = returns {
|
||||
self.infer_expression(return_ty);
|
||||
}
|
||||
|
||||
let function_ty = self.function_ty(FunctionType {
|
||||
name: Name::new(&name.id),
|
||||
decorators: decorator_tys,
|
||||
});
|
||||
|
||||
// Skip over the function or type params child scope.
|
||||
let (_, scope) = self.children_scopes.next().unwrap();
|
||||
|
||||
assert!(matches!(
|
||||
scope.kind(),
|
||||
ScopeKind::Function | ScopeKind::Annotation
|
||||
));
|
||||
|
||||
self.definition_tys
|
||||
.insert(Definition::FunctionDef(function_id), function_ty);
|
||||
}
|
||||
|
||||
fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) {
|
||||
let ast::StmtClassDef {
|
||||
range: _,
|
||||
name,
|
||||
type_params,
|
||||
decorator_list,
|
||||
arguments,
|
||||
body: _,
|
||||
} = class;
|
||||
|
||||
let class_id = class.scope_ast_id(self.db, self.file_id, self.file_scope_id);
|
||||
|
||||
for decorator in decorator_list {
|
||||
self.infer_decorator(decorator);
|
||||
}
|
||||
|
||||
let bases = arguments
|
||||
.as_deref()
|
||||
.map(|arguments| self.infer_arguments(arguments))
|
||||
.unwrap_or(Vec::new());
|
||||
|
||||
// If the class has type parameters, then the class body scope is the first child scope of the type parameters' scope.
|
||||
// Otherwise the next scope must be the class definition scope.
|
||||
let (class_body_scope_id, class_body_scope) = if type_params.is_some() {
|
||||
let (type_params_scope, _) = self.children_scopes.next().unwrap();
|
||||
self.index.child_scopes(type_params_scope).next().unwrap()
|
||||
} else {
|
||||
self.children_scopes.next().unwrap()
|
||||
};
|
||||
|
||||
assert_eq!(class_body_scope.kind(), ScopeKind::Class);
|
||||
|
||||
let class_ty = self.class_ty(ClassType {
|
||||
name: Name::new(name),
|
||||
bases,
|
||||
body_scope: class_body_scope_id.to_scope_id(self.db, self.file_id),
|
||||
});
|
||||
|
||||
self.definition_tys
|
||||
.insert(Definition::ClassDef(class_id), class_ty);
|
||||
}
|
||||
|
||||
fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) {
|
||||
let ast::StmtIf {
|
||||
range: _,
|
||||
test,
|
||||
body,
|
||||
elif_else_clauses,
|
||||
} = if_statement;
|
||||
|
||||
self.infer_expression(test);
|
||||
self.infer_body(body);
|
||||
|
||||
for clause in elif_else_clauses {
|
||||
let ast::ElifElseClause {
|
||||
range: _,
|
||||
test,
|
||||
body,
|
||||
} = clause;
|
||||
|
||||
if let Some(test) = &test {
|
||||
self.infer_expression(test);
|
||||
}
|
||||
|
||||
self.infer_body(body);
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_assignment_statement(&mut self, assignment: &ast::StmtAssign) {
|
||||
let ast::StmtAssign {
|
||||
range: _,
|
||||
targets,
|
||||
value,
|
||||
} = assignment;
|
||||
|
||||
let value_ty = self.infer_expression(value);
|
||||
|
||||
for target in targets {
|
||||
self.infer_expression(target);
|
||||
}
|
||||
|
||||
let assign_id = assignment.scope_ast_id(self.db, self.file_id, self.file_scope_id);
|
||||
|
||||
// TODO: Handle multiple targets.
|
||||
self.definition_tys
|
||||
.insert(Definition::Assignment(assign_id), value_ty);
|
||||
}
|
||||
|
||||
fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) {
|
||||
let ast::StmtAnnAssign {
|
||||
range: _,
|
||||
target,
|
||||
annotation,
|
||||
value,
|
||||
simple: _,
|
||||
} = assignment;
|
||||
|
||||
if let Some(value) = value {
|
||||
let _ = self.infer_expression(value);
|
||||
}
|
||||
|
||||
let annotation_ty = self.infer_expression(annotation);
|
||||
self.infer_expression(target);
|
||||
|
||||
self.definition_tys.insert(
|
||||
Definition::AnnotatedAssignment(assignment.scope_ast_id(
|
||||
self.db,
|
||||
self.file_id,
|
||||
self.file_scope_id,
|
||||
)),
|
||||
annotation_ty,
|
||||
);
|
||||
}
|
||||
|
||||
fn infer_for_statement(&mut self, for_statement: &ast::StmtFor) {
|
||||
let ast::StmtFor {
|
||||
range: _,
|
||||
target,
|
||||
iter,
|
||||
body,
|
||||
orelse,
|
||||
is_async: _,
|
||||
} = for_statement;
|
||||
|
||||
self.infer_expression(iter);
|
||||
self.infer_expression(target);
|
||||
self.infer_body(body);
|
||||
self.infer_body(orelse);
|
||||
}
|
||||
|
||||
fn infer_import_statement(&mut self, import: &ast::StmtImport) {
|
||||
let ast::StmtImport { range: _, names } = import;
|
||||
|
||||
let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id);
|
||||
|
||||
for (i, alias) in names.iter().enumerate() {
|
||||
let ast::Alias {
|
||||
range: _,
|
||||
name,
|
||||
asname: _,
|
||||
} = alias;
|
||||
|
||||
let module_name = ModuleName::new(&name.id);
|
||||
let module = module_name.and_then(|name| resolve_module(self.db, name));
|
||||
let module_ty = module
|
||||
.map(|module| self.typing_context().module_ty(module.file()))
|
||||
.unwrap_or(Type::Unknown);
|
||||
|
||||
self.definition_tys.insert(
|
||||
Definition::Import(ImportDefinition {
|
||||
import_id,
|
||||
alias: u32::try_from(i).unwrap(),
|
||||
}),
|
||||
module_ty,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_import_from_statement(&mut self, import: &ast::StmtImportFrom) {
|
||||
let ast::StmtImportFrom {
|
||||
range: _,
|
||||
module,
|
||||
names,
|
||||
level: _,
|
||||
} = import;
|
||||
|
||||
let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id);
|
||||
let module_name = ModuleName::new(module.as_deref().expect("Support relative imports"));
|
||||
|
||||
let module = module_name.and_then(|module_name| resolve_module(self.db, module_name));
|
||||
let module_ty = module
|
||||
.map(|module| self.typing_context().module_ty(module.file()))
|
||||
.unwrap_or(Type::Unknown);
|
||||
|
||||
for (i, alias) in names.iter().enumerate() {
|
||||
let ast::Alias {
|
||||
range: _,
|
||||
name,
|
||||
asname: _,
|
||||
} = alias;
|
||||
|
||||
let ty = module_ty
|
||||
.member(&self.typing_context(), &Name::new(&name.id))
|
||||
.unwrap_or(Type::Unknown);
|
||||
|
||||
self.definition_tys.insert(
|
||||
Definition::ImportFrom(ImportFromDefinition {
|
||||
import_id,
|
||||
name: u32::try_from(i).unwrap(),
|
||||
}),
|
||||
ty,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_decorator(&mut self, decorator: &ast::Decorator) -> Type {
|
||||
let ast::Decorator {
|
||||
range: _,
|
||||
expression,
|
||||
} = decorator;
|
||||
|
||||
self.infer_expression(expression)
|
||||
}
|
||||
|
||||
fn infer_arguments(&mut self, arguments: &ast::Arguments) -> Vec<Type> {
|
||||
let mut types = Vec::with_capacity(
|
||||
arguments
|
||||
.args
|
||||
.len()
|
||||
.saturating_add(arguments.keywords.len()),
|
||||
);
|
||||
|
||||
types.extend(arguments.args.iter().map(|arg| self.infer_expression(arg)));
|
||||
|
||||
types.extend(arguments.keywords.iter().map(
|
||||
|ast::Keyword {
|
||||
range: _,
|
||||
arg: _,
|
||||
value,
|
||||
}| self.infer_expression(value),
|
||||
));
|
||||
|
||||
types
|
||||
}
|
||||
|
||||
fn infer_expression(&mut self, expression: &ast::Expr) -> Type {
|
||||
let ty = match expression {
|
||||
ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _ }) => Type::None,
|
||||
ast::Expr::NumberLiteral(literal) => self.infer_number_literal_expression(literal),
|
||||
ast::Expr::Name(name) => self.infer_name_expression(name),
|
||||
ast::Expr::Attribute(attribute) => self.infer_attribute_expression(attribute),
|
||||
ast::Expr::BinOp(binary) => self.infer_binary_expression(binary),
|
||||
ast::Expr::Named(named) => self.infer_named_expression(named),
|
||||
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression),
|
||||
|
||||
_ => todo!("expression type resolution for {:?}", expression),
|
||||
};
|
||||
|
||||
self.types.expression_tys.push(ty);
|
||||
|
||||
ty
|
||||
}

    #[allow(clippy::unused_self)]
    fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type {
        let ast::ExprNumberLiteral { range: _, value } = literal;

        match value {
            ast::Number::Int(n) => {
                // TODO support big int literals
                n.as_i64().map(Type::IntLiteral).unwrap_or(Type::Unknown)
            }
            // TODO builtins.float or builtins.complex
            _ => Type::Unknown,
        }
    }
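
    // Illustrative note (not in the original source): the literal `1` infers as
    // `Type::IntLiteral(1)`; integers that do not fit in an `i64`, as well as
    // float and complex literals, currently fall back to `Type::Unknown`.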

    fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type {
        let ast::ExprNamed {
            range: _,
            target,
            value,
        } = named;

        let value_ty = self.infer_expression(value);
        self.infer_expression(target);

        self.definition_tys.insert(
            Definition::NamedExpr(named.scope_ast_id(self.db, self.file_id, self.file_scope_id)),
            value_ty,
        );

        value_ty
    }
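
    // Illustrative note (not in the original source): for `(y := 1)`, the walrus
    // target's definition is recorded with the value's type, so `y` later
    // resolves to `Literal[1]` (see the `walrus` test below).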

    fn infer_if_expression(&mut self, if_expression: &ast::ExprIf) -> Type {
        let ast::ExprIf {
            range: _,
            test,
            body,
            orelse,
        } = if_expression;

        self.infer_expression(test);

        // TODO detect statically known truthy or falsy test
        let body_ty = self.infer_expression(body);
        let orelse_ty = self.infer_expression(orelse);

        let union = UnionTypeBuilder::new(&self.typing_context())
            .add(body_ty)
            .add(orelse_ty)
            .build();

        self.union_ty(union)
    }
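
    // Illustrative note (not in the original source): `1 if flag else 2` infers
    // as the union of both branches, displayed as `Literal[1, 2]` (see the
    // `ifexpr` test below).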

    fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type {
        let ast::ExprName { range: _, id, ctx } = name;

        match ctx {
            ExprContext::Load => {
                if let Some(symbol_id) = self
                    .index
                    .symbol_table(self.file_scope_id)
                    .symbol_id_by_name(id)
                {
                    self.local_definition_ty(symbol_id)
                } else {
                    let ancestors = self.index.ancestor_scopes(self.file_scope_id).skip(1);

                    for (ancestor_id, _) in ancestors {
                        // TODO: Skip over class scopes unless they are an immediately-nested type param scope.
                        // TODO: Support built-ins

                        let symbol_table =
                            symbol_table(self.db, ancestor_id.to_scope_id(self.db, self.file_id));

                        if let Some(_symbol_id) = symbol_table.symbol_id_by_name(id) {
                            todo!("Return type for symbol from outer scope");
                        }
                    }
                    Type::Unknown
                }
            }
            ExprContext::Del => Type::None,
            ExprContext::Invalid => Type::Unknown,
            ExprContext::Store => Type::None,
        }
    }

    fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type {
        let ast::ExprAttribute {
            value,
            attr,
            range: _,
            ctx,
        } = attribute;

        let value_ty = self.infer_expression(value);
        let member_ty = value_ty
            .member(&self.typing_context(), &Name::new(&attr.id))
            .unwrap_or(Type::Unknown);

        match ctx {
            ExprContext::Load => member_ty,
            ExprContext::Store | ExprContext::Del => Type::None,
            ExprContext::Invalid => Type::Unknown,
        }
    }
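
    // Illustrative note (not in the original source): for `b.C` in a load
    // context, the member `C` is looked up on the inferred type of `b` (see the
    // `resolve_module_member` test below); stores and deletes infer as
    // `Type::None`.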

    fn infer_binary_expression(&mut self, binary: &ast::ExprBinOp) -> Type {
        let ast::ExprBinOp {
            left,
            op,
            right,
            range: _,
        } = binary;

        let left_ty = self.infer_expression(left);
        let right_ty = self.infer_expression(right);

        match left_ty {
            Type::Any => Type::Any,
            Type::Unknown => Type::Unknown,
            Type::IntLiteral(n) => {
                match right_ty {
                    Type::IntLiteral(m) => {
                        match op {
                            ast::Operator::Add => n
                                .checked_add(m)
                                .map(Type::IntLiteral)
                                // TODO builtins.int
                                .unwrap_or(Type::Unknown),
                            ast::Operator::Sub => n
                                .checked_sub(m)
                                .map(Type::IntLiteral)
                                // TODO builtins.int
                                .unwrap_or(Type::Unknown),
                            ast::Operator::Mult => n
                                .checked_mul(m)
                                .map(Type::IntLiteral)
                                // TODO builtins.int
                                .unwrap_or(Type::Unknown),
                            ast::Operator::Div => n
                                .checked_div(m)
                                .map(Type::IntLiteral)
                                // TODO builtins.int
                                .unwrap_or(Type::Unknown),
                            ast::Operator::Mod => n
                                .checked_rem(m)
                                .map(Type::IntLiteral)
                                // TODO division by zero error
                                .unwrap_or(Type::Unknown),
                            _ => todo!("complete binop op support for IntLiteral"),
                        }
                    }
                    _ => todo!("complete binop right_ty support for IntLiteral"),
                }
            }
            _ => todo!("complete binop support"),
        }
    }
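
    // Illustrative note (not in the original source): arithmetic on two int
    // literals is folded on the literal values, so `2 + 1` infers as
    // `Literal[3]`, and `/` uses `checked_div`, i.e. integer division on the
    // literals (see the `literal_int_arithmetic` test below). Overflow falls
    // back to `Type::Unknown`.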

    fn infer_type_parameters(&mut self, _type_parameters: &TypeParams) {
        todo!("Infer type parameters")
    }

    pub(super) fn finish(mut self) -> TypeInference {
        let symbol_tys: IndexVec<_, _> = self
            .index
            .symbol_table(self.file_scope_id)
            .symbol_ids()
            .map(|symbol| self.local_definition_ty(symbol))
            .collect();

        self.types.symbol_tys = symbol_tys;
        self.types.shrink_to_fit();
        self.types
    }

    fn union_ty(&mut self, ty: UnionType) -> Type {
        Type::Union(TypeId {
            scope: self.scope,
            scoped: self.types.union_types.push(ty),
        })
    }

    fn function_ty(&mut self, ty: FunctionType) -> Type {
        Type::Function(TypeId {
            scope: self.scope,
            scoped: self.types.function_types.push(ty),
        })
    }

    fn class_ty(&mut self, ty: ClassType) -> Type {
        Type::Class(TypeId {
            scope: self.scope,
            scoped: self.types.class_types.push(ty),
        })
    }

    fn typing_context(&self) -> TypingContext {
        TypingContext::scoped(self.db, self.scope, &self.types)
    }

    fn local_definition_ty(&mut self, symbol: ScopedSymbolId) -> Type {
        let symbol = self.symbol_table.symbol(symbol);
        let mut definitions = symbol
            .definitions()
            .iter()
            .filter_map(|definition| self.definition_tys.get(definition).copied());

        let Some(first) = definitions.next() else {
            return Type::Unbound;
        };

        if let Some(second) = definitions.next() {
            let context = self.typing_context();
            let mut builder = UnionTypeBuilder::new(&context);
            builder = builder.add(first).add(second);

            for variant in definitions {
                builder = builder.add(variant);
            }

            self.union_ty(builder.build())
        } else {
            first
        }
    }
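
    // Illustrative note (not in the original source): a symbol with several
    // reaching definitions gets the union of their types, e.g. `x = 1` on one
    // branch and `x = 2` on the other yields `Literal[1, 2]` (see the
    // `resolve_union` test below).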
}

#[cfg(test)]
mod tests {
    use ruff_db::file_system::FileSystemPathBuf;
    use ruff_db::vfs::system_path_to_file;

    use crate::db::tests::TestDb;
    use crate::module::resolver::{set_module_resolution_settings, ModuleResolutionSettings};
    use crate::name::Name;
    use crate::types::{public_symbol_ty_by_name, Type, TypingContext};

    fn setup_db() -> TestDb {
        let mut db = TestDb::new();

        set_module_resolution_settings(
            &mut db,
            ModuleResolutionSettings {
                extra_paths: Vec::new(),
                workspace_root: FileSystemPathBuf::from("/src"),
                site_packages: None,
                custom_typeshed: None,
            },
        );

        db
    }

    fn assert_public_ty(db: &TestDb, file_name: &str, symbol_name: &str, expected: &str) {
        let file = system_path_to_file(db, file_name).expect("Expected file to exist.");

        let ty = public_symbol_ty_by_name(db, file, symbol_name).unwrap_or(Type::Unknown);
        assert_eq!(ty.display(&TypingContext::global(db)).to_string(), expected);
    }

    #[test]
    fn follow_import_to_class() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system().write_files([
            ("src/a.py", "from b import C as D; E = D"),
            ("src/b.py", "class C: pass"),
        ])?;

        assert_public_ty(&db, "src/a.py", "E", "Literal[C]");

        Ok(())
    }

    #[test]
    fn resolve_base_class_by_name() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system().write_file(
            "src/mod.py",
            r#"
class Base:
    pass

class Sub(Base):
    pass"#,
        )?;

        let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist.");
        let ty = public_symbol_ty_by_name(&db, mod_file, "Sub").expect("Symbol type to exist");

        let Type::Class(class_id) = ty else {
            panic!("Sub is not a Class")
        };

        let context = TypingContext::global(&db);

        let base_names: Vec<_> = class_id
            .lookup(&context)
            .bases()
            .iter()
            .map(|base_ty| format!("{}", base_ty.display(&context)))
            .collect();

        assert_eq!(base_names, vec!["Literal[Base]"]);

        Ok(())
    }

    #[test]
    fn resolve_method() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system().write_file(
            "src/mod.py",
            "
class C:
    def f(self): pass
",
        )?;

        let mod_file = system_path_to_file(&db, "src/mod.py").unwrap();
        let ty = public_symbol_ty_by_name(&db, mod_file, "C").unwrap();

        let Type::Class(class_id) = ty else {
            panic!("C is not a Class");
        };

        let context = TypingContext::global(&db);
        let member_ty = class_id.class_member(&context, &Name::new("f"));

        let Some(Type::Function(func_id)) = member_ty else {
            panic!("C.f is not a Function");
        };

        let function_ty = func_id.lookup(&context);
        assert_eq!(function_ty.name(), "f");

        Ok(())
    }

    #[test]
    fn resolve_module_member() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system().write_files([
            ("src/a.py", "import b; D = b.C"),
            ("src/b.py", "class C: pass"),
        ])?;

        assert_public_ty(&db, "src/a.py", "D", "Literal[C]");

        Ok(())
    }

    #[test]
    fn resolve_literal() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system().write_file("src/a.py", "x = 1")?;

        assert_public_ty(&db, "src/a.py", "x", "Literal[1]");

        Ok(())
    }

    #[test]
    fn resolve_union() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system().write_file(
            "src/a.py",
            "
if flag:
    x = 1
else:
    x = 2
",
        )?;

        assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");

        Ok(())
    }

    #[test]
    fn literal_int_arithmetic() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system().write_file(
            "src/a.py",
            "
a = 2 + 1
b = a - 4
c = a * b
d = c / 3
e = 5 % 3
",
        )?;

        assert_public_ty(&db, "src/a.py", "a", "Literal[3]");
        assert_public_ty(&db, "src/a.py", "b", "Literal[-1]");
        assert_public_ty(&db, "src/a.py", "c", "Literal[-3]");
        assert_public_ty(&db, "src/a.py", "d", "Literal[-1]");
        assert_public_ty(&db, "src/a.py", "e", "Literal[2]");

        Ok(())
    }
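
    // A hypothetical additional test (not in the original source), sketched on
    // the same pattern as `literal_int_arithmetic` above: multiplying two int
    // literals should fold via `checked_mul` in `infer_binary_expression`. The
    // test name and the expected display string are assumptions based on the
    // existing tests in this module.
    #[test]
    fn literal_int_multiplication_sketch() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system().write_file("src/a.py", "x = 3 * 4")?;

        // `3 * 4` folds to the int literal 12, displayed as `Literal[12]`.
        assert_public_ty(&db, "src/a.py", "x", "Literal[12]");

        Ok(())
    }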

    #[test]
    fn walrus() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system()
            .write_file("src/a.py", "x = (y := 1) + 1")?;

        assert_public_ty(&db, "src/a.py", "x", "Literal[2]");
        assert_public_ty(&db, "src/a.py", "y", "Literal[1]");

        Ok(())
    }

    #[test]
    fn ifexpr() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system()
            .write_file("src/a.py", "x = 1 if flag else 2")?;

        assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");

        Ok(())
    }

    #[test]
    fn ifexpr_walrus() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system().write_file(
            "src/a.py",
            "
y = z = 0
x = (y := 1) if flag else (z := 2)
a = y
b = z
",
        )?;

        assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");
        assert_public_ty(&db, "src/a.py", "a", "Literal[0, 1]");
        assert_public_ty(&db, "src/a.py", "b", "Literal[0, 2]");

        Ok(())
    }

    #[test]
    fn ifexpr_nested() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system()
            .write_file("src/a.py", "x = 1 if flag else 2 if flag2 else 3")?;

        assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2, 3]");

        Ok(())
    }

    #[test]
    fn none() -> anyhow::Result<()> {
        let db = setup_db();

        db.memory_file_system()
            .write_file("src/a.py", "x = 1 if flag else None")?;

        assert_public_ty(&db, "src/a.py", "x", "Literal[1] | None");
        Ok(())
    }
}