[red-knot] Extract red_knot_python_semantic crate (#11926)

This commit is contained in:
Micha Reiser 2024-06-20 12:24:24 +01:00 committed by GitHub
parent ed948eaefb
commit 2dfbf118d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 125 additions and 94 deletions

View file

@ -0,0 +1,36 @@
[package]
name = "red_knot_python_semantic"
version = "0.0.0"
publish = false
authors = { workspace = true }
edition = { workspace = true }
rust-version = { workspace = true }
homepage = { workspace = true }
documentation = { workspace = true }
repository = { workspace = true }
license = { workspace = true }
[dependencies]
ruff_db = { workspace = true }
ruff_index = { workspace = true }
ruff_python_ast = { workspace = true }
ruff_python_stdlib = { workspace = true }
ruff_text_size = { workspace = true }
bitflags = { workspace = true }
indexmap = { workspace = true }
salsa = { workspace = true }
smallvec = { workspace = true }
smol_str = { workspace = true }
tracing = { workspace = true }
rustc-hash = { workspace = true }
hashbrown = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
ruff_python_parser = { workspace = true }
tempfile = { workspace = true }
[lints]
workspace = true

View file

@ -0,0 +1,162 @@
use std::hash::Hash;
use std::ops::Deref;
use ruff_db::parsed::ParsedModule;
/// Ref-counted owned reference to an AST node.
///
/// The type holds an owned reference to the node's ref-counted [`ParsedModule`].
/// Holding on to the node's [`ParsedModule`] guarantees that the reference to the
/// node must still be valid.
///
/// Holding on to any [`AstNodeRef`] prevents the [`ParsedModule`] from being released.
///
/// ## Equality
/// Two `AstNodeRef` are considered equal if their wrapped nodes are equal.
#[derive(Clone)]
pub struct AstNodeRef<T> {
/// Owned reference to the node's [`ParsedModule`].
///
/// The node's reference is guaranteed to remain valid as long as it's enclosing
/// [`ParsedModule`] is alive.
_parsed: ParsedModule,
/// Pointer to the referenced node.
node: std::ptr::NonNull<T>,
}
#[allow(unsafe_code)]
impl<T> AstNodeRef<T> {
/// Creates a new `AstNodeRef` that reference `node`. The `parsed` is the [`ParsedModule`] to which
/// the `AstNodeRef` belongs.
///
/// ## Safety
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the [`ParsedModule`] to
/// which `node` belongs. It's the caller's responsibility to ensure that the invariant `node belongs to parsed` is upheld.
pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
Self {
_parsed: parsed,
node: std::ptr::NonNull::from(node),
}
}
/// Returns a reference to the wrapped node.
pub fn node(&self) -> &T {
// SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still alive
// and not moved.
unsafe { self.node.as_ref() }
}
}
impl<T> Deref for AstNodeRef<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.node()
}
}
impl<T> std::fmt::Debug for AstNodeRef<T>
where
T: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("AstNodeRef").field(&self.node()).finish()
}
}
impl<T> PartialEq for AstNodeRef<T>
where
T: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
self.node().eq(other.node())
}
}
impl<T> Eq for AstNodeRef<T> where T: Eq {}
impl<T> Hash for AstNodeRef<T>
where
T: Hash,
{
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.node().hash(state);
}
}
#[allow(unsafe_code)]
unsafe impl<T> Send for AstNodeRef<T> where T: Send {}
#[allow(unsafe_code)]
unsafe impl<T> Sync for AstNodeRef<T> where T: Sync {}
#[cfg(test)]
mod tests {
use crate::ast_node_ref::AstNodeRef;
use ruff_db::parsed::ParsedModule;
use ruff_python_ast::PySourceType;
use ruff_python_parser::parse_unchecked_source;
#[test]
#[allow(unsafe_code)]
fn equality() {
let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
let parsed = ParsedModule::new(parsed_raw.clone());
let stmt = &parsed.syntax().body[0];
let node1 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
let node2 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
assert_eq!(node1, node2);
// Compare from different trees
let cloned = ParsedModule::new(parsed_raw);
let stmt_cloned = &cloned.syntax().body[0];
let cloned_node = unsafe { AstNodeRef::new(cloned.clone(), stmt_cloned) };
assert_eq!(node1, cloned_node);
let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
let other = ParsedModule::new(other_raw);
let other_stmt = &other.syntax().body[0];
let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };
assert_ne!(node1, other_node);
}
#[allow(unsafe_code)]
#[test]
fn inequality() {
let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
let parsed = ParsedModule::new(parsed_raw.clone());
let stmt = &parsed.syntax().body[0];
let node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
let other = ParsedModule::new(other_raw);
let other_stmt = &other.syntax().body[0];
let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };
assert_ne!(node, other_node);
}
#[test]
#[allow(unsafe_code)]
fn debug() {
let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
let parsed = ParsedModule::new(parsed_raw.clone());
let stmt = &parsed.syntax().body[0];
let stmt_node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
let debug = format!("{stmt_node:?}");
assert_eq!(debug, format!("AstNodeRef({stmt:?})"));
}
}

View file

@ -0,0 +1,265 @@
use salsa::DbWithJar;
use ruff_db::{Db as SourceDb, Upcast};
use crate::module::resolver::{
file_to_module, internal::ModuleNameIngredient, internal::ModuleResolverSearchPaths,
resolve_module_query,
};
use crate::semantic_index::symbol::{public_symbols_map, scopes_map, PublicSymbolId, ScopeId};
use crate::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::types::{infer_types, public_symbol_ty};
#[salsa::jar(db=Db)]
pub struct Jar(
ModuleNameIngredient,
ModuleResolverSearchPaths,
ScopeId,
PublicSymbolId,
symbol_table,
resolve_module_query,
file_to_module,
scopes_map,
root_scope,
semantic_index,
infer_types,
public_symbol_ty,
public_symbols_map,
);
/// Database giving access to semantic information about a Python program.
pub trait Db: SourceDb + DbWithJar<Jar> + Upcast<dyn SourceDb> {}
#[cfg(test)]
pub(crate) mod tests {
use std::fmt::Formatter;
use std::marker::PhantomData;
use std::sync::Arc;
use salsa::ingredient::Ingredient;
use salsa::storage::HasIngredientsFor;
use salsa::{AsId, DebugWithDb};
use ruff_db::file_system::{FileSystem, MemoryFileSystem, OsFileSystem};
use ruff_db::vfs::Vfs;
use ruff_db::{Db as SourceDb, Jar as SourceJar, Upcast};
use super::{Db, Jar};
#[salsa::db(Jar, SourceJar)]
pub(crate) struct TestDb {
storage: salsa::Storage<Self>,
vfs: Vfs,
file_system: TestFileSystem,
events: std::sync::Arc<std::sync::Mutex<Vec<salsa::Event>>>,
}
impl TestDb {
pub(crate) fn new() -> Self {
Self {
storage: salsa::Storage::default(),
file_system: TestFileSystem::Memory(MemoryFileSystem::default()),
events: std::sync::Arc::default(),
vfs: Vfs::with_stubbed_vendored(),
}
}
/// Returns the memory file system.
///
/// ## Panics
/// If this test db isn't using a memory file system.
pub(crate) fn memory_file_system(&self) -> &MemoryFileSystem {
if let TestFileSystem::Memory(fs) = &self.file_system {
fs
} else {
panic!("The test db is not using a memory file system");
}
}
/// Uses the real file system instead of the memory file system.
///
/// This useful for testing advanced file system features like permissions, symlinks, etc.
///
/// Note that any files written to the memory file system won't be copied over.
#[allow(unused)]
pub(crate) fn with_os_file_system(&mut self) {
self.file_system = TestFileSystem::Os(OsFileSystem);
}
#[allow(unused)]
pub(crate) fn vfs_mut(&mut self) -> &mut Vfs {
&mut self.vfs
}
/// Takes the salsa events.
///
/// ## Panics
/// If there are any pending salsa snapshots.
pub(crate) fn take_salsa_events(&mut self) -> Vec<salsa::Event> {
let inner = Arc::get_mut(&mut self.events).expect("no pending salsa snapshots");
let events = inner.get_mut().unwrap();
std::mem::take(&mut *events)
}
/// Clears the salsa events.
///
/// ## Panics
/// If there are any pending salsa snapshots.
pub(crate) fn clear_salsa_events(&mut self) {
self.take_salsa_events();
}
}
impl SourceDb for TestDb {
fn file_system(&self) -> &dyn FileSystem {
match &self.file_system {
TestFileSystem::Memory(fs) => fs,
TestFileSystem::Os(fs) => fs,
}
}
fn vfs(&self) -> &Vfs {
&self.vfs
}
}
impl Upcast<dyn SourceDb> for TestDb {
fn upcast(&self) -> &(dyn SourceDb + 'static) {
self
}
}
impl Db for TestDb {}
impl salsa::Database for TestDb {
fn salsa_event(&self, event: salsa::Event) {
tracing::trace!("event: {:?}", event.debug(self));
let mut events = self.events.lock().unwrap();
events.push(event);
}
}
impl salsa::ParallelDatabase for TestDb {
fn snapshot(&self) -> salsa::Snapshot<Self> {
salsa::Snapshot::new(Self {
storage: self.storage.snapshot(),
vfs: self.vfs.snapshot(),
file_system: match &self.file_system {
TestFileSystem::Memory(memory) => TestFileSystem::Memory(memory.snapshot()),
TestFileSystem::Os(fs) => TestFileSystem::Os(fs.snapshot()),
},
events: self.events.clone(),
})
}
}
enum TestFileSystem {
Memory(MemoryFileSystem),
#[allow(unused)]
Os(OsFileSystem),
}
pub(crate) fn assert_will_run_function_query<C, Db, Jar>(
db: &Db,
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
key: C::Key,
events: &[salsa::Event],
) where
C: salsa::function::Configuration<Jar = Jar>
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
Jar: HasIngredientsFor<C>,
Db: salsa::DbWithJar<Jar>,
C::Key: AsId,
{
will_run_function_query(db, to_function, key, events, true);
}
pub(crate) fn assert_will_not_run_function_query<C, Db, Jar>(
db: &Db,
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
key: C::Key,
events: &[salsa::Event],
) where
C: salsa::function::Configuration<Jar = Jar>
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
Jar: HasIngredientsFor<C>,
Db: salsa::DbWithJar<Jar>,
C::Key: AsId,
{
will_run_function_query(db, to_function, key, events, false);
}
fn will_run_function_query<C, Db, Jar>(
db: &Db,
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
key: C::Key,
events: &[salsa::Event],
should_run: bool,
) where
C: salsa::function::Configuration<Jar = Jar>
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
Jar: HasIngredientsFor<C>,
Db: salsa::DbWithJar<Jar>,
C::Key: AsId,
{
let (jar, _) =
<_ as salsa::storage::HasJar<<C as salsa::storage::IngredientsFor>::Jar>>::jar(db);
let ingredient = jar.ingredient();
let function_ingredient = to_function(ingredient);
let ingredient_index =
<salsa::function::FunctionIngredient<C> as Ingredient<Db>>::ingredient_index(
function_ingredient,
);
let did_run = events.iter().any(|event| {
if let salsa::EventKind::WillExecute { database_key } = event.kind {
database_key.ingredient_index() == ingredient_index
&& database_key.key_index() == key.as_id()
} else {
false
}
});
if should_run && !did_run {
panic!(
"Expected query {:?} to run but it didn't",
DebugIdx {
db: PhantomData::<Db>,
value_id: key.as_id(),
ingredient: function_ingredient,
}
);
} else if !should_run && did_run {
panic!(
"Expected query {:?} not to run but it did",
DebugIdx {
db: PhantomData::<Db>,
value_id: key.as_id(),
ingredient: function_ingredient,
}
);
}
}
struct DebugIdx<'a, I, Db>
where
I: Ingredient<Db>,
{
value_id: salsa::Id,
ingredient: &'a I,
db: PhantomData<Db>,
}
impl<'a, I, Db> std::fmt::Debug for DebugIdx<'a, I, Db>
where
I: Ingredient<Db>,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.ingredient.fmt_index(Some(self.value_id), f)
}
}
}

View file

@ -0,0 +1,13 @@
pub mod ast_node_ref;
mod db;
pub mod module;
pub mod name;
mod node_key;
pub mod semantic_index;
pub mod types;
type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;
pub use db::{Db, Jar};
use rustc_hash::FxHasher;
use std::hash::BuildHasherDefault;

View file

@ -0,0 +1,10 @@
use std::hash::BuildHasherDefault;
use rustc_hash::FxHasher;
pub mod ast_node_ref;
mod node_key;
pub mod semantic_index;
pub mod types;
pub(crate) type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;

View file

@ -0,0 +1,332 @@
use std::fmt::Formatter;
use std::ops::Deref;
use std::sync::Arc;
use ruff_db::file_system::FileSystemPath;
use ruff_db::vfs::{VfsFile, VfsPath};
use ruff_python_stdlib::identifiers::is_identifier;
use crate::Db;
pub mod resolver;
/// A module name, e.g. `foo.bar`.
///
/// Always normalized to the absolute form (never a relative module name, i.e., never `.foo`).
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ModuleName(smol_str::SmolStr);
impl ModuleName {
/// Creates a new module name for `name`. Returns `Some` if `name` is a valid, absolute
/// module name and `None` otherwise.
///
/// The module name is invalid if:
///
/// * The name is empty
/// * The name is relative
/// * The name ends with a `.`
/// * The name contains a sequence of multiple dots
/// * A component of a name (the part between two dots) isn't a valid python identifier.
#[inline]
pub fn new(name: &str) -> Option<Self> {
Self::new_from_smol(smol_str::SmolStr::new(name))
}
/// Creates a new module name for `name` where `name` is a static string.
/// Returns `Some` if `name` is a valid, absolute module name and `None` otherwise.
///
/// The module name is invalid if:
///
/// * The name is empty
/// * The name is relative
/// * The name ends with a `.`
/// * The name contains a sequence of multiple dots
/// * A component of a name (the part between two dots) isn't a valid python identifier.
///
/// ## Examples
///
/// ```
/// use red_knot_python_semantic::module::ModuleName;
///
/// assert_eq!(ModuleName::new_static("foo.bar").as_deref(), Some("foo.bar"));
/// assert_eq!(ModuleName::new_static(""), None);
/// assert_eq!(ModuleName::new_static("..foo"), None);
/// assert_eq!(ModuleName::new_static(".foo"), None);
/// assert_eq!(ModuleName::new_static("foo."), None);
/// assert_eq!(ModuleName::new_static("foo..bar"), None);
/// assert_eq!(ModuleName::new_static("2000"), None);
/// ```
#[inline]
pub fn new_static(name: &'static str) -> Option<Self> {
Self::new_from_smol(smol_str::SmolStr::new_static(name))
}
fn new_from_smol(name: smol_str::SmolStr) -> Option<Self> {
if name.is_empty() {
return None;
}
if name.split('.').all(is_identifier) {
Some(Self(name))
} else {
None
}
}
/// An iterator over the components of the module name:
///
/// # Examples
///
/// ```
/// use red_knot_python_semantic::module::ModuleName;
///
/// assert_eq!(ModuleName::new_static("foo.bar.baz").unwrap().components().collect::<Vec<_>>(), vec!["foo", "bar", "baz"]);
/// ```
pub fn components(&self) -> impl DoubleEndedIterator<Item = &str> {
self.0.split('.')
}
/// The name of this module's immediate parent, if it has a parent.
///
/// # Examples
///
/// ```
/// use red_knot_python_semantic::module::ModuleName;
///
/// assert_eq!(ModuleName::new_static("foo.bar").unwrap().parent(), Some(ModuleName::new_static("foo").unwrap()));
/// assert_eq!(ModuleName::new_static("foo.bar.baz").unwrap().parent(), Some(ModuleName::new_static("foo.bar").unwrap()));
/// assert_eq!(ModuleName::new_static("root").unwrap().parent(), None);
/// ```
pub fn parent(&self) -> Option<ModuleName> {
let (parent, _) = self.0.rsplit_once('.')?;
Some(Self(smol_str::SmolStr::new(parent)))
}
/// Returns `true` if the name starts with `other`.
///
/// This is equivalent to checking if `self` is a sub-module of `other`.
///
/// # Examples
///
/// ```
/// use red_knot_python_semantic::module::ModuleName;
///
/// assert!(ModuleName::new_static("foo.bar").unwrap().starts_with(&ModuleName::new_static("foo").unwrap()));
///
/// assert!(!ModuleName::new_static("foo.bar").unwrap().starts_with(&ModuleName::new_static("bar").unwrap()));
/// assert!(!ModuleName::new_static("foo_bar").unwrap().starts_with(&ModuleName::new_static("foo").unwrap()));
/// ```
pub fn starts_with(&self, other: &ModuleName) -> bool {
let mut self_components = self.components();
let other_components = other.components();
for other_component in other_components {
if self_components.next() != Some(other_component) {
return false;
}
}
true
}
#[inline]
pub fn as_str(&self) -> &str {
&self.0
}
fn from_relative_path(path: &FileSystemPath) -> Option<Self> {
let path = if path.ends_with("__init__.py") || path.ends_with("__init__.pyi") {
path.parent()?
} else {
path
};
let name = if let Some(parent) = path.parent() {
let mut name = String::with_capacity(path.as_str().len());
for component in parent.components() {
name.push_str(component.as_os_str().to_str()?);
name.push('.');
}
// SAFETY: Unwrap is safe here or `parent` would have returned `None`.
name.push_str(path.file_stem().unwrap());
smol_str::SmolStr::from(name)
} else {
smol_str::SmolStr::new(path.file_stem()?)
};
Some(Self(name))
}
}
impl Deref for ModuleName {
type Target = str;
#[inline]
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
impl PartialEq<str> for ModuleName {
fn eq(&self, other: &str) -> bool {
self.as_str() == other
}
}
impl PartialEq<ModuleName> for str {
fn eq(&self, other: &ModuleName) -> bool {
self == other.as_str()
}
}
impl std::fmt::Display for ModuleName {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}
/// Representation of a Python module.
#[derive(Clone, PartialEq, Eq)]
pub struct Module {
inner: Arc<ModuleInner>,
}
impl Module {
/// The absolute name of the module (e.g. `foo.bar`)
pub fn name(&self) -> &ModuleName {
&self.inner.name
}
/// The file to the source code that defines this module
pub fn file(&self) -> VfsFile {
self.inner.file
}
/// The search path from which the module was resolved.
pub fn search_path(&self) -> &ModuleSearchPath {
&self.inner.search_path
}
/// Determine whether this module is a single-file module or a package
pub fn kind(&self) -> ModuleKind {
self.inner.kind
}
}
impl std::fmt::Debug for Module {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Module")
.field("name", &self.name())
.field("kind", &self.kind())
.field("file", &self.file())
.field("search_path", &self.search_path())
.finish()
}
}
impl salsa::DebugWithDb<dyn Db> for Module {
fn fmt(&self, f: &mut Formatter<'_>, db: &dyn Db) -> std::fmt::Result {
f.debug_struct("Module")
.field("name", &self.name())
.field("kind", &self.kind())
.field("file", &self.file().debug(db.upcast()))
.field("search_path", &self.search_path())
.finish()
}
}
#[derive(PartialEq, Eq)]
struct ModuleInner {
name: ModuleName,
kind: ModuleKind,
search_path: ModuleSearchPath,
file: VfsFile,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum ModuleKind {
/// A single-file module (e.g. `foo.py` or `foo.pyi`)
Module,
/// A python package (`foo/__init__.py` or `foo/__init__.pyi`)
Package,
}
/// A search path in which to search modules.
/// Corresponds to a path in [`sys.path`](https://docs.python.org/3/library/sys_path_init.html) at runtime.
///
/// Cloning a search path is cheap because it's an `Arc`.
#[derive(Clone, PartialEq, Eq)]
pub struct ModuleSearchPath {
inner: Arc<ModuleSearchPathInner>,
}
impl ModuleSearchPath {
pub fn new<P>(path: P, kind: ModuleSearchPathKind) -> Self
where
P: Into<VfsPath>,
{
Self {
inner: Arc::new(ModuleSearchPathInner {
path: path.into(),
kind,
}),
}
}
/// Determine whether this is a first-party, third-party or standard-library search path
pub fn kind(&self) -> ModuleSearchPathKind {
self.inner.kind
}
/// Return the location of the search path on the file system
pub fn path(&self) -> &VfsPath {
&self.inner.path
}
}
impl std::fmt::Debug for ModuleSearchPath {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ModuleSearchPath")
.field("path", &self.inner.path)
.field("kind", &self.kind())
.finish()
}
}
#[derive(Eq, PartialEq)]
struct ModuleSearchPathInner {
path: VfsPath,
kind: ModuleSearchPathKind,
}
/// Enumeration of the different kinds of search paths type checkers are expected to support.
///
/// N.B. Although we don't implement `Ord` for this enum, they are ordered in terms of the
/// priority that we want to give these modules when resolving them.
/// This is roughly [the order given in the typing spec], but typeshed's stubs
/// for the standard library are moved higher up to match Python's semantics at runtime.
///
/// [the order given in the typing spec]: https://typing.readthedocs.io/en/latest/spec/distributing.html#import-resolution-ordering
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum ModuleSearchPathKind {
/// "Extra" paths provided by the user in a config file, env var or CLI flag.
/// E.g. mypy's `MYPYPATH` env var, or pyright's `stubPath` configuration setting
Extra,
/// Files in the project we're directly being invoked on
FirstParty,
/// The `stdlib` directory of typeshed (either vendored or custom)
StandardLibrary,
/// Stubs or runtime modules installed in site-packages
SitePackagesThirdParty,
/// Vendored third-party stubs from typeshed
VendoredThirdParty,
}

View file

@ -0,0 +1,944 @@
use std::ops::Deref;
use std::sync::Arc;
use ruff_db::file_system::{FileSystem, FileSystemPath, FileSystemPathBuf};
use ruff_db::vfs::{system_path_to_file, vfs_path_to_file, VfsFile, VfsPath};
use crate::module::resolver::internal::ModuleResolverSearchPaths;
use crate::module::{
Module, ModuleInner, ModuleKind, ModuleName, ModuleSearchPath, ModuleSearchPathKind,
};
use crate::Db;
const TYPESHED_STDLIB_DIRECTORY: &str = "stdlib";
/// Configures the module search paths for the module resolver.
///
/// Must be called before calling any other module resolution functions.
pub fn set_module_resolution_settings(db: &mut dyn Db, config: ModuleResolutionSettings) {
// There's no concurrency issue here because we hold a `&mut dyn Db` reference. No other
// thread can mutate the `Db` while we're in this call, so using `try_get` to test if
// the settings have already been set is safe.
if let Some(existing) = ModuleResolverSearchPaths::try_get(db) {
existing
.set_search_paths(db)
.to(config.into_ordered_search_paths());
} else {
ModuleResolverSearchPaths::new(db, config.into_ordered_search_paths());
}
}
/// Resolves a module name to a module.
#[tracing::instrument(level = "debug", skip(db))]
pub fn resolve_module(db: &dyn Db, module_name: ModuleName) -> Option<Module> {
let interned_name = internal::ModuleNameIngredient::new(db, module_name);
resolve_module_query(db, interned_name)
}
/// Salsa query that resolves an interned [`ModuleNameIngredient`] to a module.
///
/// This query should not be called directly. Instead, use [`resolve_module`]. It only exists
/// because Salsa requires the module name to be an ingredient.
#[salsa::tracked]
pub(crate) fn resolve_module_query(
db: &dyn Db,
module_name: internal::ModuleNameIngredient,
) -> Option<Module> {
let name = module_name.name(db);
let (search_path, module_file, kind) = resolve_name(db, name)?;
let module = Module {
inner: Arc::new(ModuleInner {
name: name.clone(),
kind,
search_path,
file: module_file,
}),
};
Some(module)
}
/// Resolves the module for the given path.
///
/// Returns `None` if the path is not a module locatable via `sys.path`.
#[tracing::instrument(level = "debug", skip(db))]
pub fn path_to_module(db: &dyn Db, path: &VfsPath) -> Option<Module> {
// It's not entirely clear on first sight why this method calls `file_to_module` instead of
// it being the other way round, considering that the first thing that `file_to_module` does
// is to retrieve the file's path.
//
// The reason is that `file_to_module` is a tracked Salsa query and salsa queries require that
// all arguments are Salsa ingredients (something stored in Salsa). `Path`s aren't salsa ingredients but
// `VfsFile` is. So what we do here is to retrieve the `path`'s `VfsFile` so that we can make
// use of Salsa's caching and invalidation.
let file = vfs_path_to_file(db.upcast(), path)?;
file_to_module(db, file)
}
/// Resolves the module for the file with the given id.
///
/// Returns `None` if the file is not a module locatable via `sys.path`.
#[salsa::tracked]
#[tracing::instrument(level = "debug", skip(db))]
pub fn file_to_module(db: &dyn Db, file: VfsFile) -> Option<Module> {
let path = file.path(db.upcast());
let search_paths = module_search_paths(db);
let relative_path = search_paths
.iter()
.find_map(|root| match (root.path(), path) {
(VfsPath::FileSystem(root_path), VfsPath::FileSystem(path)) => {
let relative_path = path.strip_prefix(root_path).ok()?;
Some(relative_path)
}
(VfsPath::Vendored(_), VfsPath::Vendored(_)) => {
todo!("Add support for vendored modules")
}
(VfsPath::Vendored(_), VfsPath::FileSystem(_))
| (VfsPath::FileSystem(_), VfsPath::Vendored(_)) => None,
})?;
let module_name = ModuleName::from_relative_path(relative_path)?;
// Resolve the module name to see if Python would resolve the name to the same path.
// If it doesn't, then that means that multiple modules have the same name in different
// root paths, but that the module corresponding to `path` is in a lower priority search path,
// in which case we ignore it.
let module = resolve_module(db, module_name)?;
if file == module.file() {
Some(module)
} else {
// This path is for a module with the same name but with a different precedence. For example:
// ```
// src/foo.py
// src/foo/__init__.py
// ```
// The module name of `src/foo.py` is `foo`, but the module loaded by Python is `src/foo/__init__.py`.
// That means we need to ignore `src/foo.py` even though it resolves to the same module name.
None
}
}
/// Configures the [`ModuleSearchPath`]s that are used to resolve modules.
#[derive(Eq, PartialEq, Debug)]
pub struct ModuleResolutionSettings {
/// List of user-provided paths that should take first priority in the module resolution.
/// Examples in other type checkers are mypy's MYPYPATH environment variable,
/// or pyright's stubPath configuration setting.
pub extra_paths: Vec<FileSystemPathBuf>,
/// The root of the workspace, used for finding first-party modules.
pub workspace_root: FileSystemPathBuf,
/// The path to the user's `site-packages` directory, where third-party packages from ``PyPI`` are installed.
pub site_packages: Option<FileSystemPathBuf>,
/// Optional path to standard-library typeshed stubs.
/// Currently this has to be a directory that exists on disk.
///
/// (TODO: fall back to vendored stubs if no custom directory is provided.)
pub custom_typeshed: Option<FileSystemPathBuf>,
}
impl ModuleResolutionSettings {
/// Implementation of PEP 561's module resolution order
/// (with some small, deliberate, differences)
fn into_ordered_search_paths(self) -> OrderedSearchPaths {
let ModuleResolutionSettings {
extra_paths,
workspace_root,
site_packages,
custom_typeshed,
} = self;
let mut paths: Vec<_> = extra_paths
.into_iter()
.map(|path| ModuleSearchPath::new(path, ModuleSearchPathKind::Extra))
.collect();
paths.push(ModuleSearchPath::new(
workspace_root,
ModuleSearchPathKind::FirstParty,
));
// TODO fallback to vendored typeshed stubs if no custom typeshed directory is provided by the user
if let Some(custom_typeshed) = custom_typeshed {
paths.push(ModuleSearchPath::new(
custom_typeshed.join(TYPESHED_STDLIB_DIRECTORY),
ModuleSearchPathKind::StandardLibrary,
));
}
// TODO vendor typeshed's third-party stubs as well as the stdlib and fallback to them as a final step
if let Some(site_packages) = site_packages {
paths.push(ModuleSearchPath::new(
site_packages,
ModuleSearchPathKind::SitePackagesThirdParty,
));
}
OrderedSearchPaths(paths)
}
}
/// A resolved module resolution order, implementing PEP 561
/// (with some small, deliberate differences)
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub(crate) struct OrderedSearchPaths(Vec<ModuleSearchPath>);
impl Deref for OrderedSearchPaths {
type Target = [ModuleSearchPath];
fn deref(&self) -> &Self::Target {
&self.0
}
}
// The singleton methods generated by salsa are all `pub` instead of `pub(crate)` which triggers
// `unreachable_pub`. Work around this by creating a module and allow `unreachable_pub` for it.
// Salsa also generates uses to `_db` variables for `interned` which triggers `clippy::used_underscore_binding`. Suppress that too
// TODO(micha): Contribute a fix for this upstream where the singleton methods have the same visibility as the struct.
#[allow(unreachable_pub, clippy::used_underscore_binding)]
pub(crate) mod internal {
use crate::module::resolver::OrderedSearchPaths;
use crate::module::ModuleName;
#[salsa::input(singleton)]
pub(crate) struct ModuleResolverSearchPaths {
#[return_ref]
pub(super) search_paths: OrderedSearchPaths,
}
/// A thin wrapper around `ModuleName` to make it a Salsa ingredient.
///
/// This is needed because Salsa requires that all query arguments are salsa ingredients.
#[salsa::interned]
pub(crate) struct ModuleNameIngredient {
#[return_ref]
pub(super) name: ModuleName,
}
}
fn module_search_paths(db: &dyn Db) -> &[ModuleSearchPath] {
ModuleResolverSearchPaths::get(db).search_paths(db)
}
/// Given a module name and a list of search paths in which to lookup modules,
/// attempt to resolve the module name
fn resolve_name(db: &dyn Db, name: &ModuleName) -> Option<(ModuleSearchPath, VfsFile, ModuleKind)> {
let search_paths = module_search_paths(db);
for search_path in search_paths {
let mut components = name.components();
let module_name = components.next_back()?;
let VfsPath::FileSystem(fs_search_path) = search_path.path() else {
todo!("Vendored search paths are not yet supported");
};
match resolve_package(db.file_system(), fs_search_path, components) {
Ok(resolved_package) => {
let mut package_path = resolved_package.path;
package_path.push(module_name);
// Must be a `__init__.pyi` or `__init__.py` or it isn't a package.
let kind = if db.file_system().is_directory(&package_path) {
package_path.push("__init__");
ModuleKind::Package
} else {
ModuleKind::Module
};
// TODO Implement full https://peps.python.org/pep-0561/#type-checker-module-resolution-order resolution
let stub = package_path.with_extension("pyi");
if let Some(stub) = system_path_to_file(db.upcast(), &stub) {
return Some((search_path.clone(), stub, kind));
}
let module = package_path.with_extension("py");
if let Some(module) = system_path_to_file(db.upcast(), &module) {
return Some((search_path.clone(), module, kind));
}
// For regular packages, don't search the next search path. All files of that
// package must be in the same location
if resolved_package.kind.is_regular_package() {
return None;
}
}
Err(parent_kind) => {
if parent_kind.is_regular_package() {
// For regular packages, don't search the next search path.
return None;
}
}
}
}
None
}
fn resolve_package<'a, I>(
fs: &dyn FileSystem,
module_search_path: &FileSystemPath,
components: I,
) -> Result<ResolvedPackage, PackageKind>
where
I: Iterator<Item = &'a str>,
{
let mut package_path = module_search_path.to_path_buf();
// `true` if inside a folder that is a namespace package (has no `__init__.py`).
// Namespace packages are special because they can be spread across multiple search paths.
// https://peps.python.org/pep-0420/
let mut in_namespace_package = false;
// `true` if resolving a sub-package. For example, `true` when resolving `bar` of `foo.bar`.
let mut in_sub_package = false;
// For `foo.bar.baz`, test that `foo` and `baz` both contain a `__init__.py`.
for folder in components {
package_path.push(folder);
let has_init_py = fs.is_file(&package_path.join("__init__.py"))
|| fs.is_file(&package_path.join("__init__.pyi"));
if has_init_py {
in_namespace_package = false;
} else if fs.is_directory(&package_path) {
// A directory without an `__init__.py` is a namespace package, continue with the next folder.
in_namespace_package = true;
} else if in_namespace_package {
// Package not found but it is part of a namespace package.
return Err(PackageKind::Namespace);
} else if in_sub_package {
// A regular sub package wasn't found.
return Err(PackageKind::Regular);
} else {
// We couldn't find `foo` for `foo.bar.baz`, search the next search path.
return Err(PackageKind::Root);
}
in_sub_package = true;
}
let kind = if in_namespace_package {
PackageKind::Namespace
} else if in_sub_package {
PackageKind::Regular
} else {
PackageKind::Root
};
Ok(ResolvedPackage {
kind,
path: package_path,
})
}
#[derive(Debug)]
struct ResolvedPackage {
path: FileSystemPathBuf,
kind: PackageKind,
}
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
enum PackageKind {
/// A root package or module. E.g. `foo` in `foo.bar.baz` or just `foo`.
Root,
/// A regular sub-package where the parent contains an `__init__.py`.
///
/// For example, `bar` in `foo.bar` when the `foo` directory contains an `__init__.py`.
Regular,
/// A sub-package in a namespace package. A namespace package is a package without an `__init__.py`.
///
/// For example, `bar` in `foo.bar` if the `foo` directory contains no `__init__.py`.
Namespace,
}
impl PackageKind {
const fn is_regular_package(self) -> bool {
matches!(self, PackageKind::Regular)
}
}
#[cfg(test)]
mod tests {
use ruff_db::file_system::{FileSystemPath, FileSystemPathBuf};
use ruff_db::vfs::{system_path_to_file, VfsFile, VfsPath};
use crate::db::tests::TestDb;
use crate::module::{ModuleKind, ModuleName};
use super::{
path_to_module, resolve_module, set_module_resolution_settings, ModuleResolutionSettings,
TYPESHED_STDLIB_DIRECTORY,
};
struct TestCase {
db: TestDb,
src: FileSystemPathBuf,
custom_typeshed: FileSystemPathBuf,
site_packages: FileSystemPathBuf,
}
fn create_resolver() -> std::io::Result<TestCase> {
let mut db = TestDb::new();
let src = FileSystemPath::new("src").to_path_buf();
let site_packages = FileSystemPath::new("site_packages").to_path_buf();
let custom_typeshed = FileSystemPath::new("typeshed").to_path_buf();
let fs = db.memory_file_system();
fs.create_directory_all(&src)?;
fs.create_directory_all(&site_packages)?;
fs.create_directory_all(&custom_typeshed)?;
let settings = ModuleResolutionSettings {
extra_paths: vec![],
workspace_root: src.clone(),
site_packages: Some(site_packages.clone()),
custom_typeshed: Some(custom_typeshed.clone()),
};
set_module_resolution_settings(&mut db, settings);
Ok(TestCase {
db,
src,
custom_typeshed,
site_packages,
})
}
#[test]
fn first_party_module() -> anyhow::Result<()> {
let TestCase { db, src, .. } = create_resolver()?;
let foo_module_name = ModuleName::new_static("foo").unwrap();
let foo_path = src.join("foo.py");
db.memory_file_system()
.write_file(&foo_path, "print('Hello, world!')")?;
let foo_module = resolve_module(&db, foo_module_name.clone()).unwrap();
assert_eq!(
Some(&foo_module),
resolve_module(&db, foo_module_name.clone()).as_ref()
);
assert_eq!("foo", foo_module.name());
assert_eq!(&src, foo_module.search_path().path());
assert_eq!(ModuleKind::Module, foo_module.kind());
assert_eq!(&foo_path, foo_module.file().path(&db));
assert_eq!(
Some(foo_module),
path_to_module(&db, &VfsPath::FileSystem(foo_path))
);
Ok(())
}
#[test]
fn stdlib() -> anyhow::Result<()> {
let TestCase {
db,
custom_typeshed,
..
} = create_resolver()?;
let stdlib_dir = custom_typeshed.join(TYPESHED_STDLIB_DIRECTORY);
let functools_path = stdlib_dir.join("functools.py");
db.memory_file_system()
.write_file(&functools_path, "def update_wrapper(): ...")?;
let functools_module_name = ModuleName::new_static("functools").unwrap();
let functools_module = resolve_module(&db, functools_module_name.clone()).unwrap();
assert_eq!(
Some(&functools_module),
resolve_module(&db, functools_module_name).as_ref()
);
assert_eq!(&stdlib_dir, functools_module.search_path().path());
assert_eq!(ModuleKind::Module, functools_module.kind());
assert_eq!(&functools_path.clone(), functools_module.file().path(&db));
assert_eq!(
Some(functools_module),
path_to_module(&db, &VfsPath::FileSystem(functools_path))
);
Ok(())
}
#[test]
fn first_party_precedence_over_stdlib() -> anyhow::Result<()> {
let TestCase {
db,
src,
custom_typeshed,
..
} = create_resolver()?;
let stdlib_dir = custom_typeshed.join(TYPESHED_STDLIB_DIRECTORY);
let stdlib_functools_path = stdlib_dir.join("functools.py");
let first_party_functools_path = src.join("functools.py");
db.memory_file_system().write_files([
(&stdlib_functools_path, "def update_wrapper(): ..."),
(&first_party_functools_path, "def update_wrapper(): ..."),
])?;
let functools_module_name = ModuleName::new_static("functools").unwrap();
let functools_module = resolve_module(&db, functools_module_name.clone()).unwrap();
assert_eq!(
Some(&functools_module),
resolve_module(&db, functools_module_name).as_ref()
);
assert_eq!(&src, functools_module.search_path().path());
assert_eq!(ModuleKind::Module, functools_module.kind());
assert_eq!(
&first_party_functools_path.clone(),
functools_module.file().path(&db)
);
assert_eq!(
Some(functools_module),
path_to_module(&db, &VfsPath::FileSystem(first_party_functools_path))
);
Ok(())
}
// TODO: Port typeshed test case. Porting isn't possible at the moment because the vendored zip
// is part of the red knot crate
// #[test]
// fn typeshed_zip_created_at_build_time() -> anyhow::Result<()> {
// // The file path here is hardcoded in this crate's `build.rs` script.
// // Luckily this crate will fail to build if this file isn't available at build time.
// const TYPESHED_ZIP_BYTES: &[u8] =
// include_bytes!(concat!(env!("OUT_DIR"), "/zipped_typeshed.zip"));
// assert!(!TYPESHED_ZIP_BYTES.is_empty());
// let mut typeshed_zip_archive = ZipArchive::new(Cursor::new(TYPESHED_ZIP_BYTES))?;
//
// let path_to_functools = Path::new("stdlib").join("functools.pyi");
// let mut functools_module_stub = typeshed_zip_archive
// .by_name(path_to_functools.to_str().unwrap())
// .unwrap();
// assert!(functools_module_stub.is_file());
//
// let mut functools_module_stub_source = String::new();
// functools_module_stub.read_to_string(&mut functools_module_stub_source)?;
//
// assert!(functools_module_stub_source.contains("def update_wrapper("));
// Ok(())
// }
#[test]
fn resolve_package() -> anyhow::Result<()> {
let TestCase { src, db, .. } = create_resolver()?;
let foo_dir = src.join("foo");
let foo_path = foo_dir.join("__init__.py");
db.memory_file_system()
.write_file(&foo_path, "print('Hello, world!')")?;
let foo_module = resolve_module(&db, ModuleName::new_static("foo").unwrap()).unwrap();
assert_eq!("foo", foo_module.name());
assert_eq!(&src, foo_module.search_path().path());
assert_eq!(&foo_path, foo_module.file().path(&db));
assert_eq!(
Some(&foo_module),
path_to_module(&db, &VfsPath::FileSystem(foo_path)).as_ref()
);
// Resolving by directory doesn't resolve to the init file.
assert_eq!(None, path_to_module(&db, &VfsPath::FileSystem(foo_dir)));
Ok(())
}
#[test]
fn package_priority_over_module() -> anyhow::Result<()> {
let TestCase { db, src, .. } = create_resolver()?;
let foo_dir = src.join("foo");
let foo_init = foo_dir.join("__init__.py");
db.memory_file_system()
.write_file(&foo_init, "print('Hello, world!')")?;
let foo_py = src.join("foo.py");
db.memory_file_system()
.write_file(&foo_py, "print('Hello, world!')")?;
let foo_module = resolve_module(&db, ModuleName::new_static("foo").unwrap()).unwrap();
assert_eq!(&src, foo_module.search_path().path());
assert_eq!(&foo_init, foo_module.file().path(&db));
assert_eq!(ModuleKind::Package, foo_module.kind());
assert_eq!(
Some(foo_module),
path_to_module(&db, &VfsPath::FileSystem(foo_init))
);
assert_eq!(None, path_to_module(&db, &VfsPath::FileSystem(foo_py)));
Ok(())
}
#[test]
fn typing_stub_over_module() -> anyhow::Result<()> {
let TestCase { db, src, .. } = create_resolver()?;
let foo_stub = src.join("foo.pyi");
let foo_py = src.join("foo.py");
db.memory_file_system()
.write_files([(&foo_stub, "x: int"), (&foo_py, "print('Hello, world!')")])?;
let foo = resolve_module(&db, ModuleName::new_static("foo").unwrap()).unwrap();
assert_eq!(&src, foo.search_path().path());
assert_eq!(&foo_stub, foo.file().path(&db));
assert_eq!(
Some(foo),
path_to_module(&db, &VfsPath::FileSystem(foo_stub))
);
assert_eq!(None, path_to_module(&db, &VfsPath::FileSystem(foo_py)));
Ok(())
}
#[test]
fn sub_packages() -> anyhow::Result<()> {
let TestCase { db, src, .. } = create_resolver()?;
let foo = src.join("foo");
let bar = foo.join("bar");
let baz = bar.join("baz.py");
db.memory_file_system().write_files([
(&foo.join("__init__.py"), ""),
(&bar.join("__init__.py"), ""),
(&baz, "print('Hello, world!')"),
])?;
let baz_module =
resolve_module(&db, ModuleName::new_static("foo.bar.baz").unwrap()).unwrap();
assert_eq!(&src, baz_module.search_path().path());
assert_eq!(&baz, baz_module.file().path(&db));
assert_eq!(
Some(baz_module),
path_to_module(&db, &VfsPath::FileSystem(baz))
);
Ok(())
}
#[test]
fn namespace_package() -> anyhow::Result<()> {
let TestCase {
db,
src,
site_packages,
..
} = create_resolver()?;
// From [PEP420](https://peps.python.org/pep-0420/#nested-namespace-packages).
// But uses `src` for `project1` and `site_packages2` for `project2`.
// ```
// src
// parent
// child
// one.py
// site_packages
// parent
// child
// two.py
// ```
let parent1 = src.join("parent");
let child1 = parent1.join("child");
let one = child1.join("one.py");
let parent2 = site_packages.join("parent");
let child2 = parent2.join("child");
let two = child2.join("two.py");
db.memory_file_system().write_files([
(&one, "print('Hello, world!')"),
(&two, "print('Hello, world!')"),
])?;
let one_module =
resolve_module(&db, ModuleName::new_static("parent.child.one").unwrap()).unwrap();
assert_eq!(
Some(one_module),
path_to_module(&db, &VfsPath::FileSystem(one))
);
let two_module =
resolve_module(&db, ModuleName::new_static("parent.child.two").unwrap()).unwrap();
assert_eq!(
Some(two_module),
path_to_module(&db, &VfsPath::FileSystem(two))
);
Ok(())
}
#[test]
fn regular_package_in_namespace_package() -> anyhow::Result<()> {
let TestCase {
db,
src,
site_packages,
..
} = create_resolver()?;
// Adopted test case from the [PEP420 examples](https://peps.python.org/pep-0420/#nested-namespace-packages).
// The `src/parent/child` package is a regular package. Therefore, `site_packages/parent/child/two.py` should not be resolved.
// ```
// src
// parent
// child
// one.py
// site_packages
// parent
// child
// two.py
// ```
let parent1 = src.join("parent");
let child1 = parent1.join("child");
let one = child1.join("one.py");
let parent2 = site_packages.join("parent");
let child2 = parent2.join("child");
let two = child2.join("two.py");
db.memory_file_system().write_files([
(&child1.join("__init__.py"), "print('Hello, world!')"),
(&one, "print('Hello, world!')"),
(&two, "print('Hello, world!')"),
])?;
let one_module =
resolve_module(&db, ModuleName::new_static("parent.child.one").unwrap()).unwrap();
assert_eq!(
Some(one_module),
path_to_module(&db, &VfsPath::FileSystem(one))
);
assert_eq!(
None,
resolve_module(&db, ModuleName::new_static("parent.child.two").unwrap())
);
Ok(())
}
#[test]
fn module_search_path_priority() -> anyhow::Result<()> {
let TestCase {
db,
src,
site_packages,
..
} = create_resolver()?;
let foo_src = src.join("foo.py");
let foo_site_packages = site_packages.join("foo.py");
db.memory_file_system()
.write_files([(&foo_src, ""), (&foo_site_packages, "")])?;
let foo_module = resolve_module(&db, ModuleName::new_static("foo").unwrap()).unwrap();
assert_eq!(&src, foo_module.search_path().path());
assert_eq!(&foo_src, foo_module.file().path(&db));
assert_eq!(
Some(foo_module),
path_to_module(&db, &VfsPath::FileSystem(foo_src))
);
assert_eq!(
None,
path_to_module(&db, &VfsPath::FileSystem(foo_site_packages))
);
Ok(())
}
#[test]
#[cfg(target_family = "unix")]
fn symlink() -> anyhow::Result<()> {
let TestCase {
mut db,
src,
site_packages,
custom_typeshed,
} = create_resolver()?;
db.with_os_file_system();
let temp_dir = tempfile::tempdir()?;
let root = FileSystemPath::from_std_path(temp_dir.path()).unwrap();
let src = root.join(src);
let site_packages = root.join(site_packages);
let custom_typeshed = root.join(custom_typeshed);
let foo = src.join("foo.py");
let bar = src.join("bar.py");
std::fs::create_dir_all(src.as_std_path())?;
std::fs::create_dir_all(site_packages.as_std_path())?;
std::fs::create_dir_all(custom_typeshed.as_std_path())?;
std::fs::write(foo.as_std_path(), "")?;
std::os::unix::fs::symlink(foo.as_std_path(), bar.as_std_path())?;
let settings = ModuleResolutionSettings {
extra_paths: vec![],
workspace_root: src.clone(),
site_packages: Some(site_packages),
custom_typeshed: Some(custom_typeshed),
};
set_module_resolution_settings(&mut db, settings);
let foo_module = resolve_module(&db, ModuleName::new_static("foo").unwrap()).unwrap();
let bar_module = resolve_module(&db, ModuleName::new_static("bar").unwrap()).unwrap();
assert_ne!(foo_module, bar_module);
assert_eq!(&src, foo_module.search_path().path());
assert_eq!(&foo, foo_module.file().path(&db));
// `foo` and `bar` shouldn't resolve to the same file
assert_eq!(&src, bar_module.search_path().path());
assert_eq!(&bar, bar_module.file().path(&db));
assert_eq!(&foo, foo_module.file().path(&db));
assert_ne!(&foo_module, &bar_module);
assert_eq!(
Some(foo_module),
path_to_module(&db, &VfsPath::FileSystem(foo))
);
assert_eq!(
Some(bar_module),
path_to_module(&db, &VfsPath::FileSystem(bar))
);
Ok(())
}
#[test]
fn deleting_an_unrealted_file_doesnt_change_module_resolution() -> anyhow::Result<()> {
let TestCase { mut db, src, .. } = create_resolver()?;
let foo_path = src.join("foo.py");
let bar_path = src.join("bar.py");
db.memory_file_system()
.write_files([(&foo_path, "x = 1"), (&bar_path, "y = 2")])?;
let foo_module_name = ModuleName::new_static("foo").unwrap();
let foo_module = resolve_module(&db, foo_module_name.clone()).unwrap();
let bar = system_path_to_file(&db, &bar_path).expect("bar.py to exist");
db.clear_salsa_events();
// Delete `bar.py`
db.memory_file_system().remove_file(&bar_path)?;
bar.touch(&mut db);
// Re-query the foo module. The foo module should still be cached because `bar.py` isn't relevant
// for resolving `foo`.
let foo_module2 = resolve_module(&db, foo_module_name);
assert!(!db
.take_salsa_events()
.iter()
.any(|event| { matches!(event.kind, salsa::EventKind::WillExecute { .. }) }));
assert_eq!(Some(foo_module), foo_module2);
Ok(())
}
#[test]
fn adding_a_file_on_which_the_module_resolution_depends_on_invalidates_the_query(
) -> anyhow::Result<()> {
let TestCase { mut db, src, .. } = create_resolver()?;
let foo_path = src.join("foo.py");
let foo_module_name = ModuleName::new_static("foo").unwrap();
assert_eq!(resolve_module(&db, foo_module_name.clone()), None);
// Now write the foo file
db.memory_file_system().write_file(&foo_path, "x = 1")?;
VfsFile::touch_path(&mut db, &VfsPath::FileSystem(foo_path.clone()));
let foo_file = system_path_to_file(&db, &foo_path).expect("foo.py to exist");
let foo_module = resolve_module(&db, foo_module_name).expect("Foo module to resolve");
assert_eq!(foo_file, foo_module.file());
Ok(())
}
#[test]
fn removing_a_file_that_the_module_resolution_depends_on_invalidates_the_query(
) -> anyhow::Result<()> {
let TestCase { mut db, src, .. } = create_resolver()?;
let foo_path = src.join("foo.py");
let foo_init_path = src.join("foo/__init__.py");
db.memory_file_system()
.write_files([(&foo_path, "x = 1"), (&foo_init_path, "x = 2")])?;
let foo_module_name = ModuleName::new_static("foo").unwrap();
let foo_module = resolve_module(&db, foo_module_name.clone()).expect("foo module to exist");
assert_eq!(&foo_init_path, foo_module.file().path(&db));
// Delete `foo/__init__.py` and the `foo` folder. `foo` should now resolve to `foo.py`
db.memory_file_system().remove_file(&foo_init_path)?;
db.memory_file_system()
.remove_directory(foo_init_path.parent().unwrap())?;
VfsFile::touch_path(&mut db, &VfsPath::FileSystem(foo_init_path.clone()));
let foo_module = resolve_module(&db, foo_module_name).expect("Foo module to resolve");
assert_eq!(&foo_path, foo_module.file().path(&db));
Ok(())
}
}

View file

@ -0,0 +1,56 @@
use std::ops::Deref;
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Name(smol_str::SmolStr);
impl Name {
#[inline]
pub fn new(name: &str) -> Self {
Self(smol_str::SmolStr::new(name))
}
#[inline]
pub fn new_static(name: &'static str) -> Self {
Self(smol_str::SmolStr::new_static(name))
}
pub fn as_str(&self) -> &str {
self.0.as_str()
}
}
impl Deref for Name {
type Target = str;
#[inline]
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
impl<T> From<T> for Name
where
T: Into<smol_str::SmolStr>,
{
fn from(value: T) -> Self {
Self(value.into())
}
}
impl std::fmt::Display for Name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
impl PartialEq<str> for Name {
fn eq(&self, other: &str) -> bool {
self.as_str() == other
}
}
impl PartialEq<Name> for str {
fn eq(&self, other: &Name) -> bool {
other == self
}
}

View file

@ -0,0 +1,24 @@
use ruff_python_ast::{AnyNodeRef, NodeKind};
use ruff_text_size::{Ranged, TextRange};
/// Compact key for a node for use in a hash map.
///
/// Compares two nodes by their kind and text range.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub(super) struct NodeKey {
kind: NodeKind,
range: TextRange,
}
impl NodeKey {
pub(super) fn from_node<'a, N>(node: N) -> Self
where
N: Into<AnyNodeRef<'a>>,
{
let node = node.into();
NodeKey {
kind: node.kind(),
range: node.range(),
}
}
}

View file

@ -0,0 +1,668 @@
use std::iter::FusedIterator;
use std::sync::Arc;
use rustc_hash::FxHashMap;
use ruff_db::parsed::parsed_module;
use ruff_db::vfs::VfsFile;
use ruff_index::{IndexSlice, IndexVec};
use ruff_python_ast as ast;
use crate::node_key::NodeKey;
use crate::semantic_index::ast_ids::{AstId, AstIds, ScopeClassId, ScopeFunctionId};
use crate::semantic_index::builder::SemanticIndexBuilder;
use crate::semantic_index::symbol::{
FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable,
};
use crate::Db;
pub mod ast_ids;
mod builder;
pub mod definition;
pub mod symbol;
type SymbolMap = hashbrown::HashMap<ScopedSymbolId, (), ()>;
/// Returns the semantic index for `file`.
///
/// Prefer using [`symbol_table`] when working with symbols from a single scope.
#[salsa::tracked(return_ref, no_eq)]
pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex {
let parsed = parsed_module(db.upcast(), file);
SemanticIndexBuilder::new(parsed).build()
}
/// Returns the symbol table for a specific `scope`.
///
/// Using [`symbol_table`] over [`semantic_index`] has the advantage that
/// Salsa can avoid invalidating dependent queries if this scope's symbol table
/// is unchanged.
#[salsa::tracked]
pub(crate) fn symbol_table(db: &dyn Db, scope: ScopeId) -> Arc<SymbolTable> {
let index = semantic_index(db, scope.file(db));
index.symbol_table(scope.file_scope_id(db))
}
/// Returns the root scope of `file`.
#[salsa::tracked]
pub(crate) fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId {
FileScopeId::root().to_scope_id(db, file)
}
/// Returns the symbol with the given name in `file`'s public scope or `None` if
/// no symbol with the given name exists.
pub fn public_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option<PublicSymbolId> {
let root_scope = root_scope(db, file);
let symbol_table = symbol_table(db, root_scope);
let local = symbol_table.symbol_id_by_name(name)?;
Some(local.to_public_symbol(db, file))
}
/// The symbol tables for an entire file.
#[derive(Debug)]
pub struct SemanticIndex {
/// List of all symbol tables in this file, indexed by scope.
symbol_tables: IndexVec<FileScopeId, Arc<SymbolTable>>,
/// List of all scopes in this file.
scopes: IndexVec<FileScopeId, Scope>,
/// Maps expressions to their corresponding scope.
/// We can't use [`ExpressionId`] here, because the challenge is how to get from
/// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope).
expression_scopes: FxHashMap<NodeKey, FileScopeId>,
/// Lookup table to map between node ids and ast nodes.
///
/// Note: We should not depend on this map when analysing other files or
/// changing a file invalidates all dependents.
ast_ids: IndexVec<FileScopeId, AstIds>,
/// Map from scope to the node that introduces the scope.
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>,
}
impl SemanticIndex {
/// Returns the symbol table for a specific scope.
///
/// Use the Salsa cached [`symbol_table`] query if you only need the
/// symbol table for a single scope.
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
self.symbol_tables[scope_id].clone()
}
pub(crate) fn ast_ids(&self, scope_id: FileScopeId) -> &AstIds {
&self.ast_ids[scope_id]
}
/// Returns the ID of the `expression`'s enclosing scope.
#[allow(unused)]
pub(crate) fn expression_scope_id(&self, expression: &ast::Expr) -> FileScopeId {
self.expression_scopes[&NodeKey::from_node(expression)]
}
/// Returns the [`Scope`] of the `expression`'s enclosing scope.
#[allow(unused)]
pub(crate) fn expression_scope(&self, expression: &ast::Expr) -> &Scope {
&self.scopes[self.expression_scope_id(expression)]
}
/// Returns the [`Scope`] with the given id.
#[allow(unused)]
pub(crate) fn scope(&self, id: FileScopeId) -> &Scope {
&self.scopes[id]
}
/// Returns the id of the parent scope.
pub(crate) fn parent_scope_id(&self, scope_id: FileScopeId) -> Option<FileScopeId> {
let scope = self.scope(scope_id);
scope.parent
}
/// Returns the parent scope of `scope_id`.
#[allow(unused)]
pub(crate) fn parent_scope(&self, scope_id: FileScopeId) -> Option<&Scope> {
Some(&self.scopes[self.parent_scope_id(scope_id)?])
}
/// Returns an iterator over the descendent scopes of `scope`.
#[allow(unused)]
pub(crate) fn descendent_scopes(&self, scope: FileScopeId) -> DescendentsIter {
DescendentsIter::new(self, scope)
}
/// Returns an iterator over the direct child scopes of `scope`.
#[allow(unused)]
pub(crate) fn child_scopes(&self, scope: FileScopeId) -> ChildrenIter {
ChildrenIter::new(self, scope)
}
/// Returns an iterator over all ancestors of `scope`, starting with `scope` itself.
#[allow(unused)]
pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter {
AncestorsIter::new(self, scope)
}
pub(crate) fn scope_node(&self, scope_id: FileScopeId) -> NodeWithScopeId {
self.scope_nodes[scope_id]
}
}
/// ID that uniquely identifies an expression inside a [`Scope`].
pub struct AncestorsIter<'a> {
scopes: &'a IndexSlice<FileScopeId, Scope>,
next_id: Option<FileScopeId>,
}
impl<'a> AncestorsIter<'a> {
fn new(module_symbol_table: &'a SemanticIndex, start: FileScopeId) -> Self {
Self {
scopes: &module_symbol_table.scopes,
next_id: Some(start),
}
}
}
impl<'a> Iterator for AncestorsIter<'a> {
type Item = (FileScopeId, &'a Scope);
fn next(&mut self) -> Option<Self::Item> {
let current_id = self.next_id?;
let current = &self.scopes[current_id];
self.next_id = current.parent;
Some((current_id, current))
}
}
impl FusedIterator for AncestorsIter<'_> {}
pub struct DescendentsIter<'a> {
next_id: FileScopeId,
descendents: std::slice::Iter<'a, Scope>,
}
impl<'a> DescendentsIter<'a> {
fn new(symbol_table: &'a SemanticIndex, scope_id: FileScopeId) -> Self {
let scope = &symbol_table.scopes[scope_id];
let scopes = &symbol_table.scopes[scope.descendents.clone()];
Self {
next_id: scope_id + 1,
descendents: scopes.iter(),
}
}
}
impl<'a> Iterator for DescendentsIter<'a> {
type Item = (FileScopeId, &'a Scope);
fn next(&mut self) -> Option<Self::Item> {
let descendent = self.descendents.next()?;
let id = self.next_id;
self.next_id = self.next_id + 1;
Some((id, descendent))
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.descendents.size_hint()
}
}
impl FusedIterator for DescendentsIter<'_> {}
impl ExactSizeIterator for DescendentsIter<'_> {}
pub struct ChildrenIter<'a> {
parent: FileScopeId,
descendents: DescendentsIter<'a>,
}
impl<'a> ChildrenIter<'a> {
fn new(module_symbol_table: &'a SemanticIndex, parent: FileScopeId) -> Self {
let descendents = DescendentsIter::new(module_symbol_table, parent);
Self {
parent,
descendents,
}
}
}
impl<'a> Iterator for ChildrenIter<'a> {
type Item = (FileScopeId, &'a Scope);
fn next(&mut self) -> Option<Self::Item> {
self.descendents
.find(|(_, scope)| scope.parent == Some(self.parent))
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(crate) enum NodeWithScopeId {
Module,
Class(AstId<ScopeClassId>),
ClassTypeParams(AstId<ScopeClassId>),
Function(AstId<ScopeFunctionId>),
FunctionTypeParams(AstId<ScopeFunctionId>),
}
impl NodeWithScopeId {
fn scope_kind(self) -> ScopeKind {
match self {
NodeWithScopeId::Module => ScopeKind::Module,
NodeWithScopeId::Class(_) => ScopeKind::Class,
NodeWithScopeId::Function(_) => ScopeKind::Function,
NodeWithScopeId::ClassTypeParams(_) | NodeWithScopeId::FunctionTypeParams(_) => {
ScopeKind::Annotation
}
}
}
}
impl FusedIterator for ChildrenIter<'_> {}
#[cfg(test)]
mod tests {
use ruff_db::parsed::parsed_module;
use ruff_db::vfs::{system_path_to_file, VfsFile};
use crate::db::tests::TestDb;
use crate::semantic_index::symbol::{FileScopeId, ScopeKind, SymbolTable};
use crate::semantic_index::{root_scope, semantic_index, symbol_table};
struct TestCase {
db: TestDb,
file: VfsFile,
}
fn test_case(content: impl ToString) -> TestCase {
let db = TestDb::new();
db.memory_file_system()
.write_file("test.py", content)
.unwrap();
let file = system_path_to_file(&db, "test.py").unwrap();
TestCase { db, file }
}
fn names(table: &SymbolTable) -> Vec<&str> {
table
.symbols()
.map(|symbol| symbol.name().as_str())
.collect()
}
#[test]
fn empty() {
let TestCase { db, file } = test_case("");
let root_table = symbol_table(&db, root_scope(&db, file));
assert_eq!(names(&root_table), Vec::<&str>::new());
}
#[test]
fn simple() {
let TestCase { db, file } = test_case("x");
let root_table = symbol_table(&db, root_scope(&db, file));
assert_eq!(names(&root_table), vec!["x"]);
}
#[test]
fn annotation_only() {
let TestCase { db, file } = test_case("x: int");
let root_table = symbol_table(&db, root_scope(&db, file));
assert_eq!(names(&root_table), vec!["int", "x"]);
// TODO record definition
}
#[test]
fn import() {
let TestCase { db, file } = test_case("import foo");
let root_table = symbol_table(&db, root_scope(&db, file));
assert_eq!(names(&root_table), vec!["foo"]);
let foo = root_table.symbol_by_name("foo").unwrap();
assert_eq!(foo.definitions().len(), 1);
}
#[test]
fn import_sub() {
let TestCase { db, file } = test_case("import foo.bar");
let root_table = symbol_table(&db, root_scope(&db, file));
assert_eq!(names(&root_table), vec!["foo"]);
}
#[test]
fn import_as() {
let TestCase { db, file } = test_case("import foo.bar as baz");
let root_table = symbol_table(&db, root_scope(&db, file));
assert_eq!(names(&root_table), vec!["baz"]);
}
#[test]
fn import_from() {
let TestCase { db, file } = test_case("from bar import foo");
let root_table = symbol_table(&db, root_scope(&db, file));
assert_eq!(names(&root_table), vec!["foo"]);
assert_eq!(
root_table
.symbol_by_name("foo")
.unwrap()
.definitions()
.len(),
1
);
assert!(
root_table
.symbol_by_name("foo")
.is_some_and(|symbol| { symbol.is_defined() || !symbol.is_used() }),
"symbols that are defined get the defined flag"
);
}
#[test]
fn assign() {
let TestCase { db, file } = test_case("x = foo");
let root_table = symbol_table(&db, root_scope(&db, file));
assert_eq!(names(&root_table), vec!["foo", "x"]);
assert_eq!(
root_table.symbol_by_name("x").unwrap().definitions().len(),
1
);
assert!(
root_table
.symbol_by_name("foo")
.is_some_and(|symbol| { !symbol.is_defined() && symbol.is_used() }),
"a symbol used but not defined in a scope should have only the used flag"
);
}
#[test]
fn class_scope() {
let TestCase { db, file } = test_case(
"
class C:
x = 1
y = 2
",
);
let root_table = symbol_table(&db, root_scope(&db, file));
assert_eq!(names(&root_table), vec!["C", "y"]);
let index = semantic_index(&db, file);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
assert_eq!(scopes.len(), 1);
let (class_scope_id, class_scope) = scopes[0];
assert_eq!(class_scope.kind(), ScopeKind::Class);
assert_eq!(class_scope.name(), "C");
let class_table = index.symbol_table(class_scope_id);
assert_eq!(names(&class_table), vec!["x"]);
assert_eq!(
class_table.symbol_by_name("x").unwrap().definitions().len(),
1
);
}
#[test]
fn function_scope() {
let TestCase { db, file } = test_case(
"
def func():
x = 1
y = 2
",
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
assert_eq!(names(&root_table), vec!["func", "y"]);
let scopes = index.child_scopes(FileScopeId::root()).collect::<Vec<_>>();
assert_eq!(scopes.len(), 1);
let (function_scope_id, function_scope) = scopes[0];
assert_eq!(function_scope.kind(), ScopeKind::Function);
assert_eq!(function_scope.name(), "func");
let function_table = index.symbol_table(function_scope_id);
assert_eq!(names(&function_table), vec!["x"]);
assert_eq!(
function_table
.symbol_by_name("x")
.unwrap()
.definitions()
.len(),
1
);
}
#[test]
fn dupes() {
let TestCase { db, file } = test_case(
"
def func():
x = 1
def func():
y = 2
",
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
assert_eq!(names(&root_table), vec!["func"]);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
assert_eq!(scopes.len(), 2);
let (func_scope1_id, func_scope_1) = scopes[0];
let (func_scope2_id, func_scope_2) = scopes[1];
assert_eq!(func_scope_1.kind(), ScopeKind::Function);
assert_eq!(func_scope_1.name(), "func");
assert_eq!(func_scope_2.kind(), ScopeKind::Function);
assert_eq!(func_scope_2.name(), "func");
let func1_table = index.symbol_table(func_scope1_id);
let func2_table = index.symbol_table(func_scope2_id);
assert_eq!(names(&func1_table), vec!["x"]);
assert_eq!(names(&func2_table), vec!["y"]);
assert_eq!(
root_table
.symbol_by_name("func")
.unwrap()
.definitions()
.len(),
2
);
}
#[test]
fn generic_function() {
let TestCase { db, file } = test_case(
"
def func[T]():
x = 1
",
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
assert_eq!(names(&root_table), vec!["func"]);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
assert_eq!(scopes.len(), 1);
let (ann_scope_id, ann_scope) = scopes[0];
assert_eq!(ann_scope.kind(), ScopeKind::Annotation);
assert_eq!(ann_scope.name(), "func");
let ann_table = index.symbol_table(ann_scope_id);
assert_eq!(names(&ann_table), vec!["T"]);
let scopes: Vec<_> = index.child_scopes(ann_scope_id).collect();
assert_eq!(scopes.len(), 1);
let (func_scope_id, func_scope) = scopes[0];
assert_eq!(func_scope.kind(), ScopeKind::Function);
assert_eq!(func_scope.name(), "func");
let func_table = index.symbol_table(func_scope_id);
assert_eq!(names(&func_table), vec!["x"]);
}
#[test]
fn generic_class() {
let TestCase { db, file } = test_case(
"
class C[T]:
x = 1
",
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
assert_eq!(names(&root_table), vec!["C"]);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
assert_eq!(scopes.len(), 1);
let (ann_scope_id, ann_scope) = scopes[0];
assert_eq!(ann_scope.kind(), ScopeKind::Annotation);
assert_eq!(ann_scope.name(), "C");
let ann_table = index.symbol_table(ann_scope_id);
assert_eq!(names(&ann_table), vec!["T"]);
assert!(
ann_table
.symbol_by_name("T")
.is_some_and(|s| s.is_defined() && !s.is_used()),
"type parameters are defined by the scope that introduces them"
);
let scopes: Vec<_> = index.child_scopes(ann_scope_id).collect();
assert_eq!(scopes.len(), 1);
let (func_scope_id, func_scope) = scopes[0];
assert_eq!(func_scope.kind(), ScopeKind::Class);
assert_eq!(func_scope.name(), "C");
assert_eq!(names(&index.symbol_table(func_scope_id)), vec!["x"]);
}
// TODO: After porting the control flow graph.
// #[test]
// fn reachability_trivial() {
// let parsed = parse("x = 1; x");
// let ast = parsed.syntax();
// let index = SemanticIndex::from_ast(ast);
// let table = &index.symbol_table;
// let x_sym = table
// .root_symbol_id_by_name("x")
// .expect("x symbol should exist");
// let ast::Stmt::Expr(ast::StmtExpr { value: x_use, .. }) = &ast.body[1] else {
// panic!("should be an expr")
// };
// let x_defs: Vec<_> = index
// .reachable_definitions(x_sym, x_use)
// .map(|constrained_definition| constrained_definition.definition)
// .collect();
// assert_eq!(x_defs.len(), 1);
// let Definition::Assignment(node_key) = &x_defs[0] else {
// panic!("def should be an assignment")
// };
// let Some(def_node) = node_key.resolve(ast.into()) else {
// panic!("node key should resolve")
// };
// let ast::Expr::NumberLiteral(ast::ExprNumberLiteral {
// value: ast::Number::Int(num),
// ..
// }) = &*def_node.value
// else {
// panic!("should be a number literal")
// };
// assert_eq!(*num, 1);
// }
#[test]
fn expression_scope() {
let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4");
let index = semantic_index(&db, file);
let parsed = parsed_module(&db, file);
let ast = parsed.syntax();
let x_stmt = ast.body[0].as_assign_stmt().unwrap();
let x = &x_stmt.targets[0];
assert_eq!(index.expression_scope(x).kind(), ScopeKind::Module);
assert_eq!(index.expression_scope_id(x), FileScopeId::root());
let def = ast.body[1].as_function_def_stmt().unwrap();
let y_stmt = def.body[0].as_assign_stmt().unwrap();
let y = &y_stmt.targets[0];
assert_eq!(index.expression_scope(y).kind(), ScopeKind::Function);
}
#[test]
fn scope_iterators() {
let TestCase { db, file } = test_case(
r#"
class Test:
def foo():
def bar():
...
def baz():
pass
def x():
pass"#,
);
let index = semantic_index(&db, file);
let descendents: Vec<_> = index
.descendent_scopes(FileScopeId::root())
.map(|(_, scope)| scope.name().as_str())
.collect();
assert_eq!(descendents, vec!["Test", "foo", "bar", "baz", "x"]);
let children: Vec<_> = index
.child_scopes(FileScopeId::root())
.map(|(_, scope)| scope.name.as_str())
.collect();
assert_eq!(children, vec!["Test", "x"]);
let test_class = index.child_scopes(FileScopeId::root()).next().unwrap().0;
let test_child_scopes: Vec<_> = index
.child_scopes(test_class)
.map(|(_, scope)| scope.name.as_str())
.collect();
assert_eq!(test_child_scopes, vec!["foo", "baz"]);
let bar_scope = index
.descendent_scopes(FileScopeId::root())
.nth(2)
.unwrap()
.0;
let ancestors: Vec<_> = index
.ancestor_scopes(bar_scope)
.map(|(_, scope)| scope.name())
.collect();
assert_eq!(ancestors, vec!["bar", "foo", "Test", "<module>"]);
}
}

View file

@ -0,0 +1,393 @@
use rustc_hash::FxHashMap;
use ruff_db::parsed::ParsedModule;
use ruff_db::vfs::VfsFile;
use ruff_index::{newtype_index, IndexVec};
use ruff_python_ast as ast;
use ruff_python_ast::AnyNodeRef;
use crate::ast_node_ref::AstNodeRef;
use crate::node_key::NodeKey;
use crate::semantic_index::semantic_index;
use crate::semantic_index::symbol::{FileScopeId, ScopeId};
use crate::Db;
/// AST ids for a single scope.
///
/// The motivation for building the AST ids per scope isn't about reducing invalidation because
/// the struct changes whenever the parsed AST changes. Instead, it's mainly that we can
/// build the AST ids struct when building the symbol table and also keep the property that
/// IDs of outer scopes are unaffected by changes in inner scopes.
///
/// For example, we don't want that adding new statements to `foo` changes the statement id of `x = foo()` in:
///
/// ```python
/// def foo():
/// return 5
///
/// x = foo()
/// ```
pub(crate) struct AstIds {
/// Maps expression ids to their expressions.
expressions: IndexVec<ScopeExpressionId, AstNodeRef<ast::Expr>>,
/// Maps expressions to their expression id. Uses `NodeKey` because it avoids cloning [`Parsed`].
expressions_map: FxHashMap<NodeKey, ScopeExpressionId>,
statements: IndexVec<ScopeStatementId, AstNodeRef<ast::Stmt>>,
statements_map: FxHashMap<NodeKey, ScopeStatementId>,
}
impl AstIds {
fn statement_id<'a, N>(&self, node: N) -> ScopeStatementId
where
N: Into<AnyNodeRef<'a>>,
{
self.statements_map[&NodeKey::from_node(node.into())]
}
fn expression_id<'a, N>(&self, node: N) -> ScopeExpressionId
where
N: Into<AnyNodeRef<'a>>,
{
self.expressions_map[&NodeKey::from_node(node.into())]
}
}
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for AstIds {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AstIds")
.field("expressions", &self.expressions)
.field("statements", &self.statements)
.finish()
}
}
fn ast_ids(db: &dyn Db, scope: ScopeId) -> &AstIds {
semantic_index(db, scope.file(db)).ast_ids(scope.file_scope_id(db))
}
/// Node that can be uniquely identified by an id in a [`FileScopeId`].
pub trait ScopeAstIdNode {
/// The type of the ID uniquely identifying the node.
type Id: Copy;
/// Returns the ID that uniquely identifies the node in `scope`.
///
/// ## Panics
/// Panics if the node doesn't belong to `file` or is outside `scope`.
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> Self::Id;
/// Looks up the AST node by its ID.
///
/// ## Panics
/// May panic if the `id` does not belong to the AST of `file`, or is outside `scope`.
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self
where
Self: Sized;
}
/// Extension trait for AST nodes that can be resolved by an `AstId`.
pub trait AstIdNode {
type ScopeId: Copy;
/// Resolves the AST id of the node.
///
/// ## Panics
/// May panic if the node does not belongs to `file`'s AST or is outside of `scope`. It may also
/// return an incorrect node if that's the case.
fn ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> AstId<Self::ScopeId>;
/// Resolves the AST node for `id`.
///
/// ## Panics
/// May panic if the `id` does not belong to the AST of `file` or it returns an incorrect node.
fn lookup(db: &dyn Db, file: VfsFile, id: AstId<Self::ScopeId>) -> &Self
where
Self: Sized;
}
impl<T> AstIdNode for T
where
T: ScopeAstIdNode,
{
type ScopeId = T::Id;
fn ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> AstId<Self::ScopeId> {
let in_scope_id = self.scope_ast_id(db, file, scope);
AstId { scope, in_scope_id }
}
fn lookup(db: &dyn Db, file: VfsFile, id: AstId<Self::ScopeId>) -> &Self
where
Self: Sized,
{
let scope = id.scope;
Self::lookup_in_scope(db, file, scope, id.in_scope_id)
}
}
/// Uniquely identifies an AST node in a file.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct AstId<L: Copy> {
/// The node's scope.
scope: FileScopeId,
/// The ID of the node inside [`Self::scope`].
in_scope_id: L,
}
impl<L: Copy> AstId<L> {
pub(super) fn new(scope: FileScopeId, in_scope_id: L) -> Self {
Self { scope, in_scope_id }
}
pub(super) fn in_scope_id(self) -> L {
self.in_scope_id
}
}
/// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`].
#[newtype_index]
pub struct ScopeExpressionId;
impl ScopeAstIdNode for ast::Expr {
type Id = ScopeExpressionId;
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
let scope = file_scope.to_scope_id(db, file);
let ast_ids = ast_ids(db, scope);
ast_ids.expressions_map[&NodeKey::from_node(self)]
}
fn lookup_in_scope(db: &dyn Db, file: VfsFile, file_scope: FileScopeId, id: Self::Id) -> &Self {
let scope = file_scope.to_scope_id(db, file);
let ast_ids = ast_ids(db, scope);
ast_ids.expressions[id].node()
}
}
/// Uniquely identifies an [`ast::Stmt`] in a [`FileScopeId`].
#[newtype_index]
pub struct ScopeStatementId;
impl ScopeAstIdNode for ast::Stmt {
type Id = ScopeStatementId;
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
let scope = file_scope.to_scope_id(db, file);
let ast_ids = ast_ids(db, scope);
ast_ids.statement_id(self)
}
fn lookup_in_scope(db: &dyn Db, file: VfsFile, file_scope: FileScopeId, id: Self::Id) -> &Self {
let scope = file_scope.to_scope_id(db, file);
let ast_ids = ast_ids(db, scope);
ast_ids.statements[id].node()
}
}
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeFunctionId(pub(super) ScopeStatementId);
impl ScopeAstIdNode for ast::StmtFunctionDef {
type Id = ScopeFunctionId;
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
let scope = file_scope.to_scope_id(db, file);
let ast_ids = ast_ids(db, scope);
ScopeFunctionId(ast_ids.statement_id(self))
}
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
ast::Stmt::lookup_in_scope(db, file, scope, id.0)
.as_function_def_stmt()
.unwrap()
}
}
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeClassId(pub(super) ScopeStatementId);
impl ScopeAstIdNode for ast::StmtClassDef {
type Id = ScopeClassId;
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
let scope = file_scope.to_scope_id(db, file);
let ast_ids = ast_ids(db, scope);
ScopeClassId(ast_ids.statement_id(self))
}
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0);
statement.as_class_def_stmt().unwrap()
}
}
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeAssignmentId(pub(super) ScopeStatementId);
impl ScopeAstIdNode for ast::StmtAssign {
type Id = ScopeAssignmentId;
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
let scope = file_scope.to_scope_id(db, file);
let ast_ids = ast_ids(db, scope);
ScopeAssignmentId(ast_ids.statement_id(self))
}
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0);
statement.as_assign_stmt().unwrap()
}
}
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeAnnotatedAssignmentId(ScopeStatementId);
impl ScopeAstIdNode for ast::StmtAnnAssign {
type Id = ScopeAnnotatedAssignmentId;
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
let scope = file_scope.to_scope_id(db, file);
let ast_ids = ast_ids(db, scope);
ScopeAnnotatedAssignmentId(ast_ids.statement_id(self))
}
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0);
statement.as_ann_assign_stmt().unwrap()
}
}
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeImportId(pub(super) ScopeStatementId);
impl ScopeAstIdNode for ast::StmtImport {
type Id = ScopeImportId;
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
let scope = file_scope.to_scope_id(db, file);
let ast_ids = ast_ids(db, scope);
ScopeImportId(ast_ids.statement_id(self))
}
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0);
statement.as_import_stmt().unwrap()
}
}
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeImportFromId(pub(super) ScopeStatementId);
impl ScopeAstIdNode for ast::StmtImportFrom {
type Id = ScopeImportFromId;
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
let scope = file_scope.to_scope_id(db, file);
let ast_ids = ast_ids(db, scope);
ScopeImportFromId(ast_ids.statement_id(self))
}
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0);
statement.as_import_from_stmt().unwrap()
}
}
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeNamedExprId(pub(super) ScopeExpressionId);
impl ScopeAstIdNode for ast::ExprNamed {
type Id = ScopeNamedExprId;
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
let scope = file_scope.to_scope_id(db, file);
let ast_ids = ast_ids(db, scope);
ScopeNamedExprId(ast_ids.expression_id(self))
}
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self
where
Self: Sized,
{
let expression = ast::Expr::lookup_in_scope(db, file, scope, id.0);
expression.as_named_expr().unwrap()
}
}
#[derive(Debug)]
pub(super) struct AstIdsBuilder {
expressions: IndexVec<ScopeExpressionId, AstNodeRef<ast::Expr>>,
expressions_map: FxHashMap<NodeKey, ScopeExpressionId>,
statements: IndexVec<ScopeStatementId, AstNodeRef<ast::Stmt>>,
statements_map: FxHashMap<NodeKey, ScopeStatementId>,
}
impl AstIdsBuilder {
pub(super) fn new() -> Self {
Self {
expressions: IndexVec::default(),
expressions_map: FxHashMap::default(),
statements: IndexVec::default(),
statements_map: FxHashMap::default(),
}
}
/// Adds `stmt` to the AST ids map and returns its id.
///
/// ## Safety
/// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires
/// that `stmt` is a child of `parsed`.
#[allow(unsafe_code)]
pub(super) unsafe fn record_statement(
&mut self,
stmt: &ast::Stmt,
parsed: &ParsedModule,
) -> ScopeStatementId {
let statement_id = self.statements.push(AstNodeRef::new(parsed.clone(), stmt));
self.statements_map
.insert(NodeKey::from_node(stmt), statement_id);
statement_id
}
/// Adds `expr` to the AST ids map and returns its id.
///
/// ## Safety
/// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires
/// that `expr` is a child of `parsed`.
#[allow(unsafe_code)]
pub(super) unsafe fn record_expression(
&mut self,
expr: &ast::Expr,
parsed: &ParsedModule,
) -> ScopeExpressionId {
let expression_id = self.expressions.push(AstNodeRef::new(parsed.clone(), expr));
self.expressions_map
.insert(NodeKey::from_node(expr), expression_id);
expression_id
}
pub(super) fn finish(mut self) -> AstIds {
self.expressions.shrink_to_fit();
self.expressions_map.shrink_to_fit();
self.statements.shrink_to_fit();
self.statements_map.shrink_to_fit();
AstIds {
expressions: self.expressions,
expressions_map: self.expressions_map,
statements: self.statements,
statements_map: self.statements_map,
}
}
}

View file

@ -0,0 +1,454 @@
use std::sync::Arc;
use rustc_hash::FxHashMap;
use ruff_db::parsed::ParsedModule;
use ruff_index::IndexVec;
use ruff_python_ast as ast;
use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor};
use crate::name::Name;
use crate::node_key::NodeKey;
use crate::semantic_index::ast_ids::{
AstId, AstIdsBuilder, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, ScopeImportFromId,
ScopeImportId, ScopeNamedExprId,
};
use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition};
use crate::semantic_index::symbol::{
FileScopeId, FileSymbolId, Scope, ScopedSymbolId, SymbolFlags, SymbolTableBuilder,
};
use crate::semantic_index::{NodeWithScopeId, SemanticIndex};
pub(super) struct SemanticIndexBuilder<'a> {
// Builder state
module: &'a ParsedModule,
scope_stack: Vec<FileScopeId>,
/// the definition whose target(s) we are currently walking
current_definition: Option<Definition>,
// Semantic Index fields
scopes: IndexVec<FileScopeId, Scope>,
symbol_tables: IndexVec<FileScopeId, SymbolTableBuilder>,
ast_ids: IndexVec<FileScopeId, AstIdsBuilder>,
expression_scopes: FxHashMap<NodeKey, FileScopeId>,
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>,
}
impl<'a> SemanticIndexBuilder<'a> {
pub(super) fn new(parsed: &'a ParsedModule) -> Self {
let mut builder = Self {
module: parsed,
scope_stack: Vec::new(),
current_definition: None,
scopes: IndexVec::new(),
symbol_tables: IndexVec::new(),
ast_ids: IndexVec::new(),
expression_scopes: FxHashMap::default(),
scope_nodes: IndexVec::new(),
};
builder.push_scope_with_parent(
NodeWithScopeId::Module,
&Name::new_static("<module>"),
None,
None,
None,
);
builder
}
fn current_scope(&self) -> FileScopeId {
*self
.scope_stack
.last()
.expect("Always to have a root scope")
}
fn push_scope(
&mut self,
node: NodeWithScopeId,
name: &Name,
defining_symbol: Option<FileSymbolId>,
definition: Option<Definition>,
) {
let parent = self.current_scope();
self.push_scope_with_parent(node, name, defining_symbol, definition, Some(parent));
}
fn push_scope_with_parent(
&mut self,
node: NodeWithScopeId,
name: &Name,
defining_symbol: Option<FileSymbolId>,
definition: Option<Definition>,
parent: Option<FileScopeId>,
) {
let children_start = self.scopes.next_index() + 1;
let scope = Scope {
name: name.clone(),
parent,
defining_symbol,
definition,
kind: node.scope_kind(),
descendents: children_start..children_start,
};
let scope_id = self.scopes.push(scope);
self.symbol_tables.push(SymbolTableBuilder::new());
let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new());
let scope_node_id = self.scope_nodes.push(node);
debug_assert_eq!(ast_id_scope, scope_id);
debug_assert_eq!(scope_id, scope_node_id);
self.scope_stack.push(scope_id);
}
fn pop_scope(&mut self) -> FileScopeId {
let id = self.scope_stack.pop().expect("Root scope to be present");
let children_end = self.scopes.next_index();
let scope = &mut self.scopes[id];
scope.descendents = scope.descendents.start..children_end;
id
}
fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder {
let scope_id = self.current_scope();
&mut self.symbol_tables[scope_id]
}
fn current_ast_ids(&mut self) -> &mut AstIdsBuilder {
let scope_id = self.current_scope();
&mut self.ast_ids[scope_id]
}
fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId {
let symbol_table = self.current_symbol_table();
symbol_table.add_or_update_symbol(name, flags, None)
}
fn add_or_update_symbol_with_definition(
&mut self,
name: Name,
definition: Definition,
) -> ScopedSymbolId {
let symbol_table = self.current_symbol_table();
symbol_table.add_or_update_symbol(name, SymbolFlags::IS_DEFINED, Some(definition))
}
fn with_type_params(
&mut self,
name: &Name,
with_params: &WithTypeParams,
defining_symbol: FileSymbolId,
nested: impl FnOnce(&mut Self) -> FileScopeId,
) -> FileScopeId {
let type_params = with_params.type_parameters();
if let Some(type_params) = type_params {
let type_node = match with_params {
WithTypeParams::ClassDef { id, .. } => NodeWithScopeId::ClassTypeParams(*id),
WithTypeParams::FunctionDef { id, .. } => NodeWithScopeId::FunctionTypeParams(*id),
};
self.push_scope(
type_node,
name,
Some(defining_symbol),
Some(with_params.definition()),
);
for type_param in &type_params.type_params {
let name = match type_param {
ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name,
ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => name,
ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => name,
};
self.add_or_update_symbol(Name::new(name), SymbolFlags::IS_DEFINED);
}
}
let nested_scope = nested(self);
if type_params.is_some() {
self.pop_scope();
}
nested_scope
}
pub(super) fn build(mut self) -> SemanticIndex {
let module = self.module;
self.visit_body(module.suite());
// Pop the root scope
self.pop_scope();
assert!(self.scope_stack.is_empty());
assert!(self.current_definition.is_none());
let mut symbol_tables: IndexVec<_, _> = self
.symbol_tables
.into_iter()
.map(|builder| Arc::new(builder.finish()))
.collect();
let mut ast_ids: IndexVec<_, _> = self
.ast_ids
.into_iter()
.map(super::ast_ids::AstIdsBuilder::finish)
.collect();
self.scopes.shrink_to_fit();
ast_ids.shrink_to_fit();
symbol_tables.shrink_to_fit();
self.expression_scopes.shrink_to_fit();
self.scope_nodes.shrink_to_fit();
SemanticIndex {
symbol_tables,
scopes: self.scopes,
scope_nodes: self.scope_nodes,
ast_ids,
expression_scopes: self.expression_scopes,
}
}
}
impl Visitor<'_> for SemanticIndexBuilder<'_> {
fn visit_stmt(&mut self, stmt: &ast::Stmt) {
let module = self.module;
#[allow(unsafe_code)]
let statement_id = unsafe {
// SAFETY: The builder only visits nodes that are part of `module`. This guarantees that
// the current statement must be a child of `module`.
self.current_ast_ids().record_statement(stmt, module)
};
match stmt {
ast::Stmt::FunctionDef(function_def) => {
for decorator in &function_def.decorator_list {
self.visit_decorator(decorator);
}
let name = Name::new(&function_def.name.id);
let function_id = ScopeFunctionId(statement_id);
let definition = Definition::FunctionDef(function_id);
let scope = self.current_scope();
let symbol = FileSymbolId::new(
scope,
self.add_or_update_symbol_with_definition(name.clone(), definition),
);
self.with_type_params(
&name,
&WithTypeParams::FunctionDef {
node: function_def,
id: AstId::new(scope, function_id),
},
symbol,
|builder| {
builder.visit_parameters(&function_def.parameters);
for expr in &function_def.returns {
builder.visit_annotation(expr);
}
builder.push_scope(
NodeWithScopeId::Function(AstId::new(scope, function_id)),
&name,
Some(symbol),
Some(definition),
);
builder.visit_body(&function_def.body);
builder.pop_scope()
},
);
}
ast::Stmt::ClassDef(class) => {
for decorator in &class.decorator_list {
self.visit_decorator(decorator);
}
let name = Name::new(&class.name.id);
let class_id = ScopeClassId(statement_id);
let definition = Definition::from(class_id);
let scope = self.current_scope();
let id = FileSymbolId::new(
self.current_scope(),
self.add_or_update_symbol_with_definition(name.clone(), definition),
);
self.with_type_params(
&name,
&WithTypeParams::ClassDef {
node: class,
id: AstId::new(scope, class_id),
},
id,
|builder| {
if let Some(arguments) = &class.arguments {
builder.visit_arguments(arguments);
}
builder.push_scope(
NodeWithScopeId::Class(AstId::new(scope, class_id)),
&name,
Some(id),
Some(definition),
);
builder.visit_body(&class.body);
builder.pop_scope()
},
);
}
ast::Stmt::Import(ast::StmtImport { names, .. }) => {
for (i, alias) in names.iter().enumerate() {
let symbol_name = if let Some(asname) = &alias.asname {
asname.id.as_str()
} else {
alias.name.id.split('.').next().unwrap()
};
let def = Definition::Import(ImportDefinition {
import_id: ScopeImportId(statement_id),
alias: u32::try_from(i).unwrap(),
});
self.add_or_update_symbol_with_definition(Name::new(symbol_name), def);
}
}
ast::Stmt::ImportFrom(ast::StmtImportFrom {
module: _,
names,
level: _,
..
}) => {
for (i, alias) in names.iter().enumerate() {
let symbol_name = if let Some(asname) = &alias.asname {
asname.id.as_str()
} else {
alias.name.id.as_str()
};
let def = Definition::ImportFrom(ImportFromDefinition {
import_id: ScopeImportFromId(statement_id),
name: u32::try_from(i).unwrap(),
});
self.add_or_update_symbol_with_definition(Name::new(symbol_name), def);
}
}
ast::Stmt::Assign(node) => {
debug_assert!(self.current_definition.is_none());
self.visit_expr(&node.value);
self.current_definition =
Some(Definition::Assignment(ScopeAssignmentId(statement_id)));
for target in &node.targets {
self.visit_expr(target);
}
self.current_definition = None;
}
_ => {
walk_stmt(self, stmt);
}
}
}
fn visit_expr(&mut self, expr: &'_ ast::Expr) {
let module = self.module;
#[allow(unsafe_code)]
let expression_id = unsafe {
// SAFETY: The builder only visits nodes that are part of `module`. This guarantees that
// the current expression must be a child of `module`.
self.current_ast_ids().record_expression(expr, module)
};
self.expression_scopes
.insert(NodeKey::from_node(expr), self.current_scope());
match expr {
ast::Expr::Name(ast::ExprName { id, ctx, .. }) => {
let flags = match ctx {
ast::ExprContext::Load => SymbolFlags::IS_USED,
ast::ExprContext::Store => SymbolFlags::IS_DEFINED,
ast::ExprContext::Del => SymbolFlags::IS_DEFINED,
ast::ExprContext::Invalid => SymbolFlags::empty(),
};
match self.current_definition {
Some(definition) if flags.contains(SymbolFlags::IS_DEFINED) => {
self.add_or_update_symbol_with_definition(Name::new(id), definition);
}
_ => {
self.add_or_update_symbol(Name::new(id), flags);
}
}
walk_expr(self, expr);
}
ast::Expr::Named(node) => {
debug_assert!(self.current_definition.is_none());
self.current_definition =
Some(Definition::NamedExpr(ScopeNamedExprId(expression_id)));
// TODO walrus in comprehensions is implicitly nonlocal
self.visit_expr(&node.target);
self.current_definition = None;
self.visit_expr(&node.value);
}
ast::Expr::If(ast::ExprIf {
body, test, orelse, ..
}) => {
// TODO detect statically known truthy or falsy test (via type inference, not naive
// AST inspection, so we can't simplify here, need to record test expression in CFG
// for later checking)
self.visit_expr(test);
// let if_branch = self.flow_graph_builder.add_branch(self.current_flow_node());
// self.set_current_flow_node(if_branch);
// self.insert_constraint(test);
self.visit_expr(body);
// let post_body = self.current_flow_node();
// self.set_current_flow_node(if_branch);
self.visit_expr(orelse);
// let post_else = self
// .flow_graph_builder
// .add_phi(self.current_flow_node(), post_body);
// self.set_current_flow_node(post_else);
}
_ => {
walk_expr(self, expr);
}
}
}
}
enum WithTypeParams<'a> {
ClassDef {
node: &'a ast::StmtClassDef,
id: AstId<ScopeClassId>,
},
FunctionDef {
node: &'a ast::StmtFunctionDef,
id: AstId<ScopeFunctionId>,
},
}
impl<'a> WithTypeParams<'a> {
fn type_parameters(&self) -> Option<&'a ast::TypeParams> {
match self {
WithTypeParams::ClassDef { node, .. } => node.type_params.as_deref(),
WithTypeParams::FunctionDef { node, .. } => node.type_params.as_deref(),
}
}
fn definition(&self) -> Definition {
match self {
WithTypeParams::ClassDef { id, .. } => Definition::ClassDef(id.in_scope_id()),
WithTypeParams::FunctionDef { id, .. } => Definition::FunctionDef(id.in_scope_id()),
}
}
}

View file

@ -0,0 +1,76 @@
use crate::semantic_index::ast_ids::{
ScopeAnnotatedAssignmentId, ScopeAssignmentId, ScopeClassId, ScopeFunctionId,
ScopeImportFromId, ScopeImportId, ScopeNamedExprId,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Definition {
Import(ImportDefinition),
ImportFrom(ImportFromDefinition),
ClassDef(ScopeClassId),
FunctionDef(ScopeFunctionId),
Assignment(ScopeAssignmentId),
AnnotatedAssignment(ScopeAnnotatedAssignmentId),
NamedExpr(ScopeNamedExprId),
/// represents the implicit initial definition of every name as "unbound"
Unbound,
// TODO with statements, except handlers, function args...
}
impl From<ImportDefinition> for Definition {
fn from(value: ImportDefinition) -> Self {
Self::Import(value)
}
}
impl From<ImportFromDefinition> for Definition {
fn from(value: ImportFromDefinition) -> Self {
Self::ImportFrom(value)
}
}
impl From<ScopeClassId> for Definition {
fn from(value: ScopeClassId) -> Self {
Self::ClassDef(value)
}
}
impl From<ScopeFunctionId> for Definition {
fn from(value: ScopeFunctionId) -> Self {
Self::FunctionDef(value)
}
}
impl From<ScopeAssignmentId> for Definition {
fn from(value: ScopeAssignmentId) -> Self {
Self::Assignment(value)
}
}
impl From<ScopeAnnotatedAssignmentId> for Definition {
fn from(value: ScopeAnnotatedAssignmentId) -> Self {
Self::AnnotatedAssignment(value)
}
}
impl From<ScopeNamedExprId> for Definition {
fn from(value: ScopeNamedExprId) -> Self {
Self::NamedExpr(value)
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ImportDefinition {
pub(crate) import_id: ScopeImportId,
/// Index into [`ruff_python_ast::StmtImport::names`].
pub(crate) alias: u32,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ImportFromDefinition {
pub(crate) import_id: ScopeImportFromId,
/// Index into [`ruff_python_ast::StmtImportFrom::names`].
pub(crate) name: u32,
}

View file

@ -0,0 +1,379 @@
// Allow unused underscore violations generated by the salsa macro
// TODO(micha): Contribute fix upstream
#![allow(clippy::used_underscore_binding)]
use std::hash::{Hash, Hasher};
use std::ops::Range;
use bitflags::bitflags;
use hashbrown::hash_map::RawEntryMut;
use rustc_hash::FxHasher;
use smallvec::SmallVec;
use ruff_db::vfs::VfsFile;
use ruff_index::{newtype_index, IndexVec};
use crate::name::Name;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::{root_scope, semantic_index, symbol_table, SymbolMap};
use crate::Db;
#[derive(Eq, PartialEq, Debug)]
pub struct Symbol {
name: Name,
flags: SymbolFlags,
/// The nodes that define this symbol, in source order.
definitions: SmallVec<[Definition; 4]>,
}
impl Symbol {
fn new(name: Name, definition: Option<Definition>) -> Self {
Self {
name,
flags: SymbolFlags::empty(),
definitions: definition.into_iter().collect(),
}
}
fn push_definition(&mut self, definition: Definition) {
self.definitions.push(definition);
}
fn insert_flags(&mut self, flags: SymbolFlags) {
self.flags.insert(flags);
}
/// The symbol's name.
pub fn name(&self) -> &Name {
&self.name
}
/// Is the symbol used in its containing scope?
pub fn is_used(&self) -> bool {
self.flags.contains(SymbolFlags::IS_USED)
}
/// Is the symbol defined in its containing scope?
pub fn is_defined(&self) -> bool {
self.flags.contains(SymbolFlags::IS_DEFINED)
}
pub fn definitions(&self) -> &[Definition] {
&self.definitions
}
}
bitflags! {
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(super) struct SymbolFlags: u8 {
const IS_USED = 1 << 0;
const IS_DEFINED = 1 << 1;
/// TODO: This flag is not yet set by anything
const MARKED_GLOBAL = 1 << 2;
/// TODO: This flag is not yet set by anything
const MARKED_NONLOCAL = 1 << 3;
}
}
/// ID that uniquely identifies a public symbol defined in a module's root scope.
#[salsa::tracked]
pub struct PublicSymbolId {
#[id]
pub(crate) file: VfsFile,
#[id]
pub(crate) scoped_symbol_id: ScopedSymbolId,
}
/// ID that uniquely identifies a symbol in a file.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FileSymbolId {
scope: FileScopeId,
scoped_symbol_id: ScopedSymbolId,
}
impl FileSymbolId {
pub(super) fn new(scope: FileScopeId, symbol: ScopedSymbolId) -> Self {
Self {
scope,
scoped_symbol_id: symbol,
}
}
pub fn scope(self) -> FileScopeId {
self.scope
}
pub(crate) fn scoped_symbol_id(self) -> ScopedSymbolId {
self.scoped_symbol_id
}
}
impl From<FileSymbolId> for ScopedSymbolId {
fn from(val: FileSymbolId) -> Self {
val.scoped_symbol_id()
}
}
/// Symbol ID that uniquely identifies a symbol inside a [`Scope`].
#[newtype_index]
pub struct ScopedSymbolId;
impl ScopedSymbolId {
/// Converts the symbol to a public symbol.
///
/// # Panics
/// May panic if the symbol does not belong to `file` or is not a symbol of `file`'s root scope.
pub(crate) fn to_public_symbol(self, db: &dyn Db, file: VfsFile) -> PublicSymbolId {
let symbols = public_symbols_map(db, file);
symbols.public(self)
}
}
/// Returns a mapping from [`FileScopeId`] to globally unique [`ScopeId`].
#[salsa::tracked(return_ref)]
pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap {
let index = semantic_index(db, file);
let scopes: IndexVec<_, _> = index
.scopes
.indices()
.map(|id| ScopeId::new(db, file, id))
.collect();
ScopesMap { scopes }
}
/// Maps from the file specific [`FileScopeId`] to the global [`ScopeId`] that can be used as a Salsa query parameter.
///
/// The [`SemanticIndex`] uses [`FileScopeId`] on a per-file level to identify scopes
/// because they allow for more efficient storage of associated data
/// (use of an [`IndexVec`] keyed by [`FileScopeId`] over an [`FxHashMap`] keyed by [`ScopeId`]).
#[derive(Eq, PartialEq, Debug)]
pub(crate) struct ScopesMap {
scopes: IndexVec<FileScopeId, ScopeId>,
}
impl ScopesMap {
/// Gets the program-wide unique scope id for the given file specific `scope_id`.
fn get(&self, scope: FileScopeId) -> ScopeId {
self.scopes[scope]
}
}
#[salsa::tracked(return_ref)]
pub(crate) fn public_symbols_map(db: &dyn Db, file: VfsFile) -> PublicSymbolsMap {
let module_scope = root_scope(db, file);
let symbols = symbol_table(db, module_scope);
let public_symbols: IndexVec<_, _> = symbols
.symbol_ids()
.map(|id| PublicSymbolId::new(db, file, id))
.collect();
PublicSymbolsMap {
symbols: public_symbols,
}
}
/// Maps [`LocalSymbolId`] of a file's root scope to the corresponding [`PublicSymbolId`] (Salsa ingredients).
#[derive(Eq, PartialEq, Debug)]
pub(crate) struct PublicSymbolsMap {
symbols: IndexVec<ScopedSymbolId, PublicSymbolId>,
}
impl PublicSymbolsMap {
/// Resolve the [`PublicSymbolId`] for the module-level `symbol_id`.
fn public(&self, symbol_id: ScopedSymbolId) -> PublicSymbolId {
self.symbols[symbol_id]
}
}
/// A cross-module identifier of a scope that can be used as a salsa query parameter.
#[salsa::tracked]
pub struct ScopeId {
#[allow(clippy::used_underscore_binding)]
#[id]
pub file: VfsFile,
#[id]
pub file_scope_id: FileScopeId,
}
/// ID that uniquely identifies a scope inside of a module.
#[newtype_index]
pub struct FileScopeId;
impl FileScopeId {
/// Returns the scope id of the Root scope.
pub fn root() -> Self {
FileScopeId::from_u32(0)
}
pub fn to_scope_id(self, db: &dyn Db, file: VfsFile) -> ScopeId {
scopes_map(db, file).get(self)
}
}
#[derive(Debug, Eq, PartialEq)]
pub struct Scope {
pub(super) name: Name,
pub(super) parent: Option<FileScopeId>,
pub(super) definition: Option<Definition>,
pub(super) defining_symbol: Option<FileSymbolId>,
pub(super) kind: ScopeKind,
pub(super) descendents: Range<FileScopeId>,
}
impl Scope {
pub fn name(&self) -> &Name {
&self.name
}
pub fn definition(&self) -> Option<Definition> {
self.definition
}
pub fn defining_symbol(&self) -> Option<FileSymbolId> {
self.defining_symbol
}
pub fn parent(self) -> Option<FileScopeId> {
self.parent
}
pub fn kind(&self) -> ScopeKind {
self.kind
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ScopeKind {
Module,
Annotation,
Class,
Function,
}
/// Symbol table for a specific [`Scope`].
#[derive(Debug)]
pub struct SymbolTable {
/// The symbols in this scope.
symbols: IndexVec<ScopedSymbolId, Symbol>,
/// The symbols indexed by name.
symbols_by_name: SymbolMap,
}
impl SymbolTable {
fn new() -> Self {
Self {
symbols: IndexVec::new(),
symbols_by_name: SymbolMap::default(),
}
}
fn shrink_to_fit(&mut self) {
self.symbols.shrink_to_fit();
}
pub(crate) fn symbol(&self, symbol_id: impl Into<ScopedSymbolId>) -> &Symbol {
&self.symbols[symbol_id.into()]
}
#[allow(unused)]
pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopedSymbolId> {
self.symbols.indices()
}
pub fn symbols(&self) -> impl Iterator<Item = &Symbol> {
self.symbols.iter()
}
/// Returns the symbol named `name`.
#[allow(unused)]
pub(crate) fn symbol_by_name(&self, name: &str) -> Option<&Symbol> {
let id = self.symbol_id_by_name(name)?;
Some(self.symbol(id))
}
/// Returns the [`ScopedSymbolId`] of the symbol named `name`.
pub(crate) fn symbol_id_by_name(&self, name: &str) -> Option<ScopedSymbolId> {
let (id, ()) = self
.symbols_by_name
.raw_entry()
.from_hash(Self::hash_name(name), |id| {
self.symbol(*id).name().as_str() == name
})?;
Some(*id)
}
fn hash_name(name: &str) -> u64 {
let mut hasher = FxHasher::default();
name.hash(&mut hasher);
hasher.finish()
}
}
impl PartialEq for SymbolTable {
fn eq(&self, other: &Self) -> bool {
// We don't need to compare the symbols_by_name because the name is already captured in `Symbol`.
self.symbols == other.symbols
}
}
impl Eq for SymbolTable {}
#[derive(Debug)]
pub(super) struct SymbolTableBuilder {
table: SymbolTable,
}
impl SymbolTableBuilder {
pub(super) fn new() -> Self {
Self {
table: SymbolTable::new(),
}
}
pub(super) fn add_or_update_symbol(
&mut self,
name: Name,
flags: SymbolFlags,
definition: Option<Definition>,
) -> ScopedSymbolId {
let hash = SymbolTable::hash_name(&name);
let entry = self
.table
.symbols_by_name
.raw_entry_mut()
.from_hash(hash, |id| self.table.symbols[*id].name() == &name);
match entry {
RawEntryMut::Occupied(entry) => {
let symbol = &mut self.table.symbols[*entry.key()];
symbol.insert_flags(flags);
if let Some(definition) = definition {
symbol.push_definition(definition);
}
*entry.key()
}
RawEntryMut::Vacant(entry) => {
let mut symbol = Symbol::new(name, definition);
symbol.insert_flags(flags);
let id = self.table.symbols.push(symbol);
entry.insert_with_hasher(hash, id, (), |id| {
SymbolTable::hash_name(self.table.symbols[*id].name().as_str())
});
id
}
}
}
pub(super) fn finish(mut self) -> SymbolTable {
self.table.shrink_to_fit();
self.table
}
}

View file

@ -0,0 +1,682 @@
use salsa::DebugWithDb;
use ruff_db::parsed::parsed_module;
use ruff_db::vfs::VfsFile;
use ruff_index::newtype_index;
use ruff_python_ast as ast;
use crate::name::Name;
use crate::semantic_index::ast_ids::{AstIdNode, ScopeAstIdNode};
use crate::semantic_index::symbol::{FileScopeId, PublicSymbolId, ScopeId};
use crate::semantic_index::{
public_symbol, root_scope, semantic_index, symbol_table, NodeWithScopeId,
};
use crate::types::infer::{TypeInference, TypeInferenceBuilder};
use crate::Db;
use crate::FxIndexSet;
mod display;
mod infer;
/// Infers the type of `expr`.
///
/// Calling this function from a salsa query adds a dependency on [`semantic_index`]
/// which changes with every AST change. That's why you should only call
/// this function for the current file that's being analyzed and not for
/// a dependency (or the query reruns whenever a dependency change).
///
/// Prefer [`public_symbol_ty`] when resolving the type of symbol from another file.
#[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn expression_ty(db: &dyn Db, file: VfsFile, expression: &ast::Expr) -> Type {
let index = semantic_index(db, file);
let file_scope = index.expression_scope_id(expression);
let expression_id = expression.scope_ast_id(db, file, file_scope);
let scope = file_scope.to_scope_id(db, file);
infer_types(db, scope).expression_ty(expression_id)
}
/// Infers the type of a public symbol.
///
/// This is a Salsa query to get symbol-level invalidation instead of file-level dependency invalidation.
/// Without this being a query, changing any public type of a module would invalidate the type inference
/// for the module scope of its dependents and the transitive dependents because.
///
/// For example if we have
/// ```python
/// # a.py
/// import x from b
///
/// # b.py
///
/// x = 20
/// ```
///
/// And x is now changed from `x = 20` to `x = 30`. The following happens:
///
/// * The module level types of `b.py` change because `x` now is a `Literal[30]`.
/// * The module level types of `a.py` change because the imported symbol `x` now has a `Literal[30]` type
/// * The module level types of any dependents of `a.py` change because the imported symbol `x` now has a `Literal[30]` type
/// * And so on for all transitive dependencies.
///
/// This being a query ensures that the invalidation short-circuits if the type of this symbol didn't change.
#[salsa::tracked]
pub(crate) fn public_symbol_ty(db: &dyn Db, symbol: PublicSymbolId) -> Type {
let _ = tracing::debug_span!("public_symbol_ty", symbol = ?symbol.debug(db)).enter();
let file = symbol.file(db);
let scope = root_scope(db, file);
let inference = infer_types(db, scope);
inference.symbol_ty(symbol.scoped_symbol_id(db))
}
/// Shorthand for `public_symbol_ty` that takes a symbol name instead of a [`PublicSymbolId`].
pub fn public_symbol_ty_by_name(db: &dyn Db, file: VfsFile, name: &str) -> Option<Type> {
let symbol = public_symbol(db, file, name)?;
Some(public_symbol_ty(db, symbol))
}
/// Infers all types for `scope`.
#[salsa::tracked(return_ref)]
pub(crate) fn infer_types(db: &dyn Db, scope: ScopeId) -> TypeInference {
let file = scope.file(db);
// Using the index here is fine because the code below depends on the AST anyway.
// The isolation of the query is by the return inferred types.
let index = semantic_index(db, file);
let scope_id = scope.file_scope_id(db);
let node = index.scope_node(scope_id);
let mut context = TypeInferenceBuilder::new(db, scope, index);
match node {
NodeWithScopeId::Module => {
let parsed = parsed_module(db.upcast(), file);
context.infer_module(parsed.syntax());
}
NodeWithScopeId::Class(class_id) => {
let class = ast::StmtClassDef::lookup(db, file, class_id);
context.infer_class_body(class);
}
NodeWithScopeId::ClassTypeParams(class_id) => {
let class = ast::StmtClassDef::lookup(db, file, class_id);
context.infer_class_type_params(class);
}
NodeWithScopeId::Function(function_id) => {
let function = ast::StmtFunctionDef::lookup(db, file, function_id);
context.infer_function_body(function);
}
NodeWithScopeId::FunctionTypeParams(function_id) => {
let function = ast::StmtFunctionDef::lookup(db, file, function_id);
context.infer_function_type_params(function);
}
}
context.finish()
}
/// unique ID for a type
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Type {
/// the dynamic type: a statically-unknown set of values
Any,
/// the empty set of values
Never,
/// unknown type (no annotation)
/// equivalent to Any, or to object in strict mode
Unknown,
/// name is not bound to any value
Unbound,
/// the None object (TODO remove this in favor of Instance(types.NoneType)
None,
/// a specific function object
Function(TypeId<ScopedFunctionTypeId>),
/// a specific module object
Module(TypeId<ScopedModuleTypeId>),
/// a specific class object
Class(TypeId<ScopedClassTypeId>),
/// the set of Python objects with the given class in their __class__'s method resolution order
Instance(TypeId<ScopedClassTypeId>),
Union(TypeId<ScopedUnionTypeId>),
Intersection(TypeId<ScopedIntersectionTypeId>),
IntLiteral(i64),
// TODO protocols, callable types, overloads, generics, type vars
}
impl Type {
pub const fn is_unbound(&self) -> bool {
matches!(self, Type::Unbound)
}
pub const fn is_unknown(&self) -> bool {
matches!(self, Type::Unknown)
}
pub fn member(&self, context: &TypingContext, name: &Name) -> Option<Type> {
match self {
Type::Any => Some(Type::Any),
Type::Never => todo!("attribute lookup on Never type"),
Type::Unknown => Some(Type::Unknown),
Type::Unbound => todo!("attribute lookup on Unbound type"),
Type::None => todo!("attribute lookup on None type"),
Type::Function(_) => todo!("attribute lookup on Function type"),
Type::Module(module) => module.member(context, name),
Type::Class(class) => class.class_member(context, name),
Type::Instance(_) => {
// TODO MRO? get_own_instance_member, get_instance_member
todo!("attribute lookup on Instance type")
}
Type::Union(union_id) => {
let _union = union_id.lookup(context);
// TODO perform the get_member on each type in the union
// TODO return the union of those results
// TODO if any of those results is `None` then include Unknown in the result union
todo!("attribute lookup on Union type")
}
Type::Intersection(_) => {
// TODO perform the get_member on each type in the intersection
// TODO return the intersection of those results
todo!("attribute lookup on Intersection type")
}
Type::IntLiteral(_) => {
// TODO raise error
Some(Type::Unknown)
}
}
}
}
/// ID that uniquely identifies a type in a program.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TypeId<L> {
/// The scope in which this type is defined or was created.
scope: ScopeId,
/// The type's local ID in its scope.
scoped: L,
}
impl<Id> TypeId<Id>
where
Id: Copy,
{
pub fn scope(&self) -> ScopeId {
self.scope
}
pub fn scoped_id(&self) -> Id {
self.scoped
}
/// Resolves the type ID to the actual type.
pub(crate) fn lookup<'a>(self, context: &'a TypingContext) -> &'a Id::Ty
where
Id: ScopedTypeId,
{
let types = context.types(self.scope);
self.scoped.lookup_scoped(types)
}
}
/// ID that uniquely identifies a type in a scope.
pub(crate) trait ScopedTypeId {
/// The type that this ID points to.
type Ty;
/// Looks up the type in `index`.
///
/// ## Panics
/// May panic if this type is from another scope than `index`, or might just return an invalid type.
fn lookup_scoped(self, index: &TypeInference) -> &Self::Ty;
}
/// ID uniquely identifying a function type in a `scope`.
#[newtype_index]
pub struct ScopedFunctionTypeId;
impl ScopedTypeId for ScopedFunctionTypeId {
type Ty = FunctionType;
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
types.function_ty(self)
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct FunctionType {
/// name of the function at definition
name: Name,
/// types of all decorators on this function
decorators: Vec<Type>,
}
impl FunctionType {
fn name(&self) -> &str {
self.name.as_str()
}
#[allow(unused)]
pub(crate) fn decorators(&self) -> &[Type] {
self.decorators.as_slice()
}
}
#[newtype_index]
pub struct ScopedClassTypeId;
impl ScopedTypeId for ScopedClassTypeId {
type Ty = ClassType;
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
types.class_ty(self)
}
}
impl TypeId<ScopedClassTypeId> {
/// Returns the class member of this class named `name`.
///
/// The member resolves to a member of the class itself or any of its bases.
fn class_member(self, context: &TypingContext, name: &Name) -> Option<Type> {
if let Some(member) = self.own_class_member(context, name) {
return Some(member);
}
let class = self.lookup(context);
for base in &class.bases {
if let Some(member) = base.member(context, name) {
return Some(member);
}
}
None
}
/// Returns the inferred type of the class member named `name`.
fn own_class_member(self, context: &TypingContext, name: &Name) -> Option<Type> {
let class = self.lookup(context);
let symbols = symbol_table(context.db, class.body_scope);
let symbol = symbols.symbol_id_by_name(name)?;
let types = context.types(class.body_scope);
Some(types.symbol_ty(symbol))
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct ClassType {
/// Name of the class at definition
name: Name,
/// Types of all class bases
bases: Vec<Type>,
body_scope: ScopeId,
}
impl ClassType {
fn name(&self) -> &str {
self.name.as_str()
}
#[allow(unused)]
pub(super) fn bases(&self) -> &[Type] {
self.bases.as_slice()
}
}
#[newtype_index]
pub struct ScopedUnionTypeId;
impl ScopedTypeId for ScopedUnionTypeId {
type Ty = UnionType;
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
types.union_ty(self)
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct UnionType {
// the union type includes values in any of these types
elements: FxIndexSet<Type>,
}
struct UnionTypeBuilder<'a> {
elements: FxIndexSet<Type>,
context: &'a TypingContext<'a>,
}
impl<'a> UnionTypeBuilder<'a> {
fn new(context: &'a TypingContext<'a>) -> Self {
Self {
context,
elements: FxIndexSet::default(),
}
}
/// Adds a type to this union.
fn add(mut self, ty: Type) -> Self {
match ty {
Type::Union(union_id) => {
let union = union_id.lookup(self.context);
self.elements.extend(&union.elements);
}
_ => {
self.elements.insert(ty);
}
}
self
}
fn build(self) -> UnionType {
UnionType {
elements: self.elements,
}
}
}
#[newtype_index]
pub struct ScopedIntersectionTypeId;
impl ScopedTypeId for ScopedIntersectionTypeId {
type Ty = IntersectionType;
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
types.intersection_ty(self)
}
}
// Negation types aren't expressible in annotations, and are most likely to arise from type
// narrowing along with intersections (e.g. `if not isinstance(...)`), so we represent them
// directly in intersections rather than as a separate type. This sacrifices some efficiency in the
// case where a Not appears outside an intersection (unclear when that could even happen, but we'd
// have to represent it as a single-element intersection if it did) in exchange for better
// efficiency in the within-intersection case.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct IntersectionType {
// the intersection type includes only values in all of these types
positive: FxIndexSet<Type>,
// the intersection type does not include any value in any of these types
negative: FxIndexSet<Type>,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ScopedModuleTypeId;
impl ScopedTypeId for ScopedModuleTypeId {
type Ty = ModuleType;
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
types.module_ty()
}
}
impl TypeId<ScopedModuleTypeId> {
fn member(self, context: &TypingContext, name: &Name) -> Option<Type> {
context.public_symbol_ty(self.scope.file(context.db), name)
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct ModuleType {
file: VfsFile,
}
/// Context in which to resolve types.
///
/// This abstraction is necessary to support a uniform API that can be used
/// while in the process of building the type inference structure for a scope
/// but also when all types should be resolved by querying the db.
pub struct TypingContext<'a> {
db: &'a dyn Db,
/// The Local type inference scope that is in the process of being built.
///
/// Bypass the `db` when resolving the types for this scope.
local: Option<(ScopeId, &'a TypeInference)>,
}
impl<'a> TypingContext<'a> {
/// Creates a context that resolves all types by querying the db.
#[allow(unused)]
pub(super) fn global(db: &'a dyn Db) -> Self {
Self { db, local: None }
}
/// Creates a context that by-passes the `db` when resolving types from `scope_id` and instead uses `types`.
fn scoped(db: &'a dyn Db, scope_id: ScopeId, types: &'a TypeInference) -> Self {
Self {
db,
local: Some((scope_id, types)),
}
}
/// Returns the [`TypeInference`] results (not guaranteed to be complete) for `scope_id`.
fn types(&self, scope_id: ScopeId) -> &'a TypeInference {
if let Some((scope, local_types)) = self.local {
if scope == scope_id {
return local_types;
}
}
infer_types(self.db, scope_id)
}
fn module_ty(&self, file: VfsFile) -> Type {
let scope = root_scope(self.db, file);
Type::Module(TypeId {
scope,
scoped: ScopedModuleTypeId,
})
}
/// Resolves the public type of a symbol named `name` defined in `file`.
///
/// This function calls [`public_symbol_ty`] if the local scope isn't the module scope of `file`.
/// It otherwise tries to resolve the symbol type locally.
fn public_symbol_ty(&self, file: VfsFile, name: &Name) -> Option<Type> {
let symbol = public_symbol(self.db, file, name)?;
if let Some((scope, local_types)) = self.local {
if scope.file_scope_id(self.db) == FileScopeId::root() && scope.file(self.db) == file {
return Some(local_types.symbol_ty(symbol.scoped_symbol_id(self.db)));
}
}
Some(public_symbol_ty(self.db, symbol))
}
}
#[cfg(test)]
mod tests {
use ruff_db::file_system::FileSystemPathBuf;
use ruff_db::parsed::parsed_module;
use ruff_db::vfs::system_path_to_file;
use crate::db::tests::{
assert_will_not_run_function_query, assert_will_run_function_query, TestDb,
};
use crate::module::resolver::{set_module_resolution_settings, ModuleResolutionSettings};
use crate::semantic_index::root_scope;
use crate::types::{expression_ty, infer_types, public_symbol_ty_by_name, TypingContext};
fn setup_db() -> TestDb {
let mut db = TestDb::new();
set_module_resolution_settings(
&mut db,
ModuleResolutionSettings {
extra_paths: vec![],
workspace_root: FileSystemPathBuf::from("/src"),
site_packages: None,
custom_typeshed: None,
},
);
db
}
#[test]
fn local_inference() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file("/src/a.py", "x = 10")?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let parsed = parsed_module(&db, a);
let statement = parsed.suite().first().unwrap().as_assign_stmt().unwrap();
let literal_ty = expression_ty(&db, a, &statement.value);
assert_eq!(
format!("{}", literal_ty.display(&TypingContext::global(&db))),
"Literal[10]"
);
Ok(())
}
#[test]
fn dependency_public_symbol_type_change() -> anyhow::Result<()> {
let mut db = setup_db();
db.memory_file_system().write_files([
("/src/a.py", "from foo import x"),
("/src/foo.py", "x = 10\ndef foo(): ..."),
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
// Change `x` to a different value
db.memory_file_system()
.write_file("/src/foo.py", "x = 20\ndef foo(): ...")?;
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
foo.touch(&mut db);
let a = system_path_to_file(&db, "/src/a.py").unwrap();
db.clear_salsa_events();
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[20]"
);
let events = db.take_salsa_events();
let a_root_scope = root_scope(&db, a);
assert_will_run_function_query::<infer_types, _, _>(
&db,
|ty| &ty.function,
a_root_scope,
&events,
);
Ok(())
}
#[test]
fn dependency_non_public_symbol_change() -> anyhow::Result<()> {
let mut db = setup_db();
db.memory_file_system().write_files([
("/src/a.py", "from foo import x"),
("/src/foo.py", "x = 10\ndef foo(): y = 1"),
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
db.memory_file_system()
.write_file("/src/foo.py", "x = 10\ndef foo(): pass")?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
foo.touch(&mut db);
db.clear_salsa_events();
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
let events = db.take_salsa_events();
let a_root_scope = root_scope(&db, a);
assert_will_not_run_function_query::<infer_types, _, _>(
&db,
|ty| &ty.function,
a_root_scope,
&events,
);
Ok(())
}
#[test]
fn dependency_unrelated_public_symbol() -> anyhow::Result<()> {
let mut db = setup_db();
db.memory_file_system().write_files([
("/src/a.py", "from foo import x"),
("/src/foo.py", "x = 10\ny = 20"),
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
db.memory_file_system()
.write_file("/src/foo.py", "x = 10\ny = 30")?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
foo.touch(&mut db);
db.clear_salsa_events();
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
let events = db.take_salsa_events();
let a_root_scope = root_scope(&db, a);
assert_will_not_run_function_query::<infer_types, _, _>(
&db,
|ty| &ty.function,
a_root_scope,
&events,
);
Ok(())
}
}

View file

@ -0,0 +1,175 @@
//! Display implementations for types.
use std::fmt::{Display, Formatter};
use crate::types::{IntersectionType, Type, TypingContext, UnionType};
impl Type {
pub fn display<'a>(&'a self, context: &'a TypingContext) -> DisplayType<'a> {
DisplayType { ty: self, context }
}
}
#[derive(Copy, Clone)]
pub struct DisplayType<'a> {
ty: &'a Type,
context: &'a TypingContext<'a>,
}
impl Display for DisplayType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self.ty {
Type::Any => f.write_str("Any"),
Type::Never => f.write_str("Never"),
Type::Unknown => f.write_str("Unknown"),
Type::Unbound => f.write_str("Unbound"),
Type::None => f.write_str("None"),
Type::Module(module_id) => {
write!(
f,
"<module '{:?}'>",
module_id
.scope
.file(self.context.db)
.path(self.context.db.upcast())
)
}
// TODO functions and classes should display using a fully qualified name
Type::Class(class_id) => {
let class = class_id.lookup(self.context);
f.write_str("Literal[")?;
f.write_str(class.name())?;
f.write_str("]")
}
Type::Instance(class_id) => {
let class = class_id.lookup(self.context);
f.write_str(class.name())
}
Type::Function(function_id) => {
let function = function_id.lookup(self.context);
f.write_str(function.name())
}
Type::Union(union_id) => {
let union = union_id.lookup(self.context);
union.display(self.context).fmt(f)
}
Type::Intersection(intersection_id) => {
let intersection = intersection_id.lookup(self.context);
intersection.display(self.context).fmt(f)
}
Type::IntLiteral(n) => write!(f, "Literal[{n}]"),
}
}
}
impl std::fmt::Debug for DisplayType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}
impl UnionType {
fn display<'a>(&'a self, context: &'a TypingContext<'a>) -> DisplayUnionType<'a> {
DisplayUnionType { context, ty: self }
}
}
struct DisplayUnionType<'a> {
ty: &'a UnionType,
context: &'a TypingContext<'a>,
}
impl Display for DisplayUnionType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let union = self.ty;
let (int_literals, other_types): (Vec<Type>, Vec<Type>) = union
.elements
.iter()
.copied()
.partition(|ty| matches!(ty, Type::IntLiteral(_)));
let mut first = true;
if !int_literals.is_empty() {
f.write_str("Literal[")?;
let mut nums: Vec<_> = int_literals
.into_iter()
.filter_map(|ty| {
if let Type::IntLiteral(n) = ty {
Some(n)
} else {
None
}
})
.collect();
nums.sort_unstable();
for num in nums {
if !first {
f.write_str(", ")?;
}
write!(f, "{num}")?;
first = false;
}
f.write_str("]")?;
}
for ty in other_types {
if !first {
f.write_str(" | ")?;
};
first = false;
write!(f, "{}", ty.display(self.context))?;
}
Ok(())
}
}
impl std::fmt::Debug for DisplayUnionType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}
impl IntersectionType {
fn display<'a>(&'a self, context: &'a TypingContext<'a>) -> DisplayIntersectionType<'a> {
DisplayIntersectionType { ty: self, context }
}
}
struct DisplayIntersectionType<'a> {
ty: &'a IntersectionType,
context: &'a TypingContext<'a>,
}
impl Display for DisplayIntersectionType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let mut first = true;
for (neg, ty) in self
.ty
.positive
.iter()
.map(|ty| (false, ty))
.chain(self.ty.negative.iter().map(|ty| (true, ty)))
{
if !first {
f.write_str(" & ")?;
};
first = false;
if neg {
f.write_str("~")?;
};
write!(f, "{}", ty.display(self.context))?;
}
Ok(())
}
}
impl std::fmt::Debug for DisplayIntersectionType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}

View file

@ -0,0 +1,941 @@
use std::sync::Arc;
use rustc_hash::FxHashMap;
use ruff_db::vfs::VfsFile;
use ruff_index::IndexVec;
use ruff_python_ast as ast;
use ruff_python_ast::{ExprContext, TypeParams};
use crate::module::resolver::resolve_module;
use crate::module::ModuleName;
use crate::name::Name;
use crate::semantic_index::ast_ids::{ScopeAstIdNode, ScopeExpressionId};
use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition};
use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable};
use crate::semantic_index::{symbol_table, ChildrenIter, SemanticIndex};
use crate::types::{
ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId, ScopedFunctionTypeId,
ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext, UnionType,
UnionTypeBuilder,
};
use crate::Db;
/// The inferred types for a single scope.
#[derive(Debug, Eq, PartialEq, Default, Clone)]
pub(crate) struct TypeInference {
/// The type of the module if the scope is a module scope.
module_type: Option<ModuleType>,
/// The types of the defined classes in this scope.
class_types: IndexVec<ScopedClassTypeId, ClassType>,
/// The types of the defined functions in this scope.
function_types: IndexVec<ScopedFunctionTypeId, FunctionType>,
union_types: IndexVec<ScopedUnionTypeId, UnionType>,
intersection_types: IndexVec<ScopedIntersectionTypeId, IntersectionType>,
/// The types of every expression in this scope.
expression_tys: IndexVec<ScopeExpressionId, Type>,
/// The public types of every symbol in this scope.
symbol_tys: IndexVec<ScopedSymbolId, Type>,
}
impl TypeInference {
#[allow(unused)]
pub(super) fn expression_ty(&self, expression: ScopeExpressionId) -> Type {
self.expression_tys[expression]
}
pub(super) fn symbol_ty(&self, symbol: ScopedSymbolId) -> Type {
self.symbol_tys[symbol]
}
pub(super) fn module_ty(&self) -> &ModuleType {
self.module_type.as_ref().unwrap()
}
pub(super) fn class_ty(&self, id: ScopedClassTypeId) -> &ClassType {
&self.class_types[id]
}
pub(super) fn function_ty(&self, id: ScopedFunctionTypeId) -> &FunctionType {
&self.function_types[id]
}
pub(super) fn union_ty(&self, id: ScopedUnionTypeId) -> &UnionType {
&self.union_types[id]
}
pub(super) fn intersection_ty(&self, id: ScopedIntersectionTypeId) -> &IntersectionType {
&self.intersection_types[id]
}
fn shrink_to_fit(&mut self) {
self.class_types.shrink_to_fit();
self.function_types.shrink_to_fit();
self.union_types.shrink_to_fit();
self.intersection_types.shrink_to_fit();
self.expression_tys.shrink_to_fit();
self.symbol_tys.shrink_to_fit();
}
}
/// Builder to infer all types in a [`ScopeId`].
pub(super) struct TypeInferenceBuilder<'a> {
db: &'a dyn Db,
// Cached lookups
index: &'a SemanticIndex,
scope: ScopeId,
file_scope_id: FileScopeId,
file_id: VfsFile,
symbol_table: Arc<SymbolTable>,
/// The type inference results
types: TypeInference,
definition_tys: FxHashMap<Definition, Type>,
children_scopes: ChildrenIter<'a>,
}
impl<'a> TypeInferenceBuilder<'a> {
/// Creates a new builder for inferring the types of `scope`.
pub(super) fn new(db: &'a dyn Db, scope: ScopeId, index: &'a SemanticIndex) -> Self {
let file_scope_id = scope.file_scope_id(db);
let file = scope.file(db);
let children_scopes = index.child_scopes(file_scope_id);
let symbol_table = index.symbol_table(file_scope_id);
Self {
index,
file_scope_id,
file_id: file,
scope,
symbol_table,
db,
types: TypeInference::default(),
definition_tys: FxHashMap::default(),
children_scopes,
}
}
/// Infers the types of a `module`.
pub(super) fn infer_module(&mut self, module: &ast::ModModule) {
self.infer_body(&module.body);
}
pub(super) fn infer_class_type_params(&mut self, class: &ast::StmtClassDef) {
if let Some(type_params) = class.type_params.as_deref() {
self.infer_type_parameters(type_params);
}
}
pub(super) fn infer_class_body(&mut self, class: &ast::StmtClassDef) {
self.infer_body(&class.body);
}
pub(super) fn infer_function_type_params(&mut self, function: &ast::StmtFunctionDef) {
if let Some(type_params) = function.type_params.as_deref() {
self.infer_type_parameters(type_params);
}
}
pub(super) fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) {
self.infer_body(&function.body);
}
fn infer_body(&mut self, suite: &[ast::Stmt]) {
for statement in suite {
self.infer_statement(statement);
}
}
fn infer_statement(&mut self, statement: &ast::Stmt) {
match statement {
ast::Stmt::FunctionDef(function) => self.infer_function_definition_statement(function),
ast::Stmt::ClassDef(class) => self.infer_class_definition_statement(class),
ast::Stmt::Expr(ast::StmtExpr { range: _, value }) => {
self.infer_expression(value);
}
ast::Stmt::If(if_statement) => self.infer_if_statement(if_statement),
ast::Stmt::Assign(assign) => self.infer_assignment_statement(assign),
ast::Stmt::AnnAssign(assign) => self.infer_annotated_assignment_statement(assign),
ast::Stmt::For(for_statement) => self.infer_for_statement(for_statement),
ast::Stmt::Import(import) => self.infer_import_statement(import),
ast::Stmt::ImportFrom(import) => self.infer_import_from_statement(import),
ast::Stmt::Break(_) | ast::Stmt::Continue(_) | ast::Stmt::Pass(_) => {
// No-op
}
_ => {}
}
}
fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) {
let ast::StmtFunctionDef {
range: _,
is_async: _,
name,
type_params: _,
parameters: _,
returns,
body: _,
decorator_list,
} = function;
let function_id = function.scope_ast_id(self.db, self.file_id, self.file_scope_id);
let decorator_tys = decorator_list
.iter()
.map(|decorator| self.infer_decorator(decorator))
.collect();
// TODO: Infer parameters
if let Some(return_ty) = returns {
self.infer_expression(return_ty);
}
let function_ty = self.function_ty(FunctionType {
name: Name::new(&name.id),
decorators: decorator_tys,
});
// Skip over the function or type params child scope.
let (_, scope) = self.children_scopes.next().unwrap();
assert!(matches!(
scope.kind(),
ScopeKind::Function | ScopeKind::Annotation
));
self.definition_tys
.insert(Definition::FunctionDef(function_id), function_ty);
}
fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) {
let ast::StmtClassDef {
range: _,
name,
type_params,
decorator_list,
arguments,
body: _,
} = class;
let class_id = class.scope_ast_id(self.db, self.file_id, self.file_scope_id);
for decorator in decorator_list {
self.infer_decorator(decorator);
}
let bases = arguments
.as_deref()
.map(|arguments| self.infer_arguments(arguments))
.unwrap_or(Vec::new());
// If the class has type parameters, then the class body scope is the first child scope of the type parameter's scope
// Otherwise the next scope must be the class definition scope.
let (class_body_scope_id, class_body_scope) = if type_params.is_some() {
let (type_params_scope, _) = self.children_scopes.next().unwrap();
self.index.child_scopes(type_params_scope).next().unwrap()
} else {
self.children_scopes.next().unwrap()
};
assert_eq!(class_body_scope.kind(), ScopeKind::Class);
let class_ty = self.class_ty(ClassType {
name: Name::new(name),
bases,
body_scope: class_body_scope_id.to_scope_id(self.db, self.file_id),
});
self.definition_tys
.insert(Definition::ClassDef(class_id), class_ty);
}
fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) {
let ast::StmtIf {
range: _,
test,
body,
elif_else_clauses,
} = if_statement;
self.infer_expression(test);
self.infer_body(body);
for clause in elif_else_clauses {
let ast::ElifElseClause {
range: _,
test,
body,
} = clause;
if let Some(test) = &test {
self.infer_expression(test);
}
self.infer_body(body);
}
}
fn infer_assignment_statement(&mut self, assignment: &ast::StmtAssign) {
let ast::StmtAssign {
range: _,
targets,
value,
} = assignment;
let value_ty = self.infer_expression(value);
for target in targets {
self.infer_expression(target);
}
let assign_id = assignment.scope_ast_id(self.db, self.file_id, self.file_scope_id);
// TODO: Handle multiple targets.
self.definition_tys
.insert(Definition::Assignment(assign_id), value_ty);
}
fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) {
let ast::StmtAnnAssign {
range: _,
target,
annotation,
value,
simple: _,
} = assignment;
if let Some(value) = value {
let _ = self.infer_expression(value);
}
let annotation_ty = self.infer_expression(annotation);
self.infer_expression(target);
self.definition_tys.insert(
Definition::AnnotatedAssignment(assignment.scope_ast_id(
self.db,
self.file_id,
self.file_scope_id,
)),
annotation_ty,
);
}
fn infer_for_statement(&mut self, for_statement: &ast::StmtFor) {
let ast::StmtFor {
range: _,
target,
iter,
body,
orelse,
is_async: _,
} = for_statement;
self.infer_expression(iter);
self.infer_expression(target);
self.infer_body(body);
self.infer_body(orelse);
}
fn infer_import_statement(&mut self, import: &ast::StmtImport) {
let ast::StmtImport { range: _, names } = import;
let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id);
for (i, alias) in names.iter().enumerate() {
let ast::Alias {
range: _,
name,
asname: _,
} = alias;
let module_name = ModuleName::new(&name.id);
let module = module_name.and_then(|name| resolve_module(self.db, name));
let module_ty = module
.map(|module| self.typing_context().module_ty(module.file()))
.unwrap_or(Type::Unknown);
self.definition_tys.insert(
Definition::Import(ImportDefinition {
import_id,
alias: u32::try_from(i).unwrap(),
}),
module_ty,
);
}
}
fn infer_import_from_statement(&mut self, import: &ast::StmtImportFrom) {
let ast::StmtImportFrom {
range: _,
module,
names,
level: _,
} = import;
let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id);
let module_name = ModuleName::new(module.as_deref().expect("Support relative imports"));
let module = module_name.and_then(|module_name| resolve_module(self.db, module_name));
let module_ty = module
.map(|module| self.typing_context().module_ty(module.file()))
.unwrap_or(Type::Unknown);
for (i, alias) in names.iter().enumerate() {
let ast::Alias {
range: _,
name,
asname: _,
} = alias;
let ty = module_ty
.member(&self.typing_context(), &Name::new(&name.id))
.unwrap_or(Type::Unknown);
self.definition_tys.insert(
Definition::ImportFrom(ImportFromDefinition {
import_id,
name: u32::try_from(i).unwrap(),
}),
ty,
);
}
}
fn infer_decorator(&mut self, decorator: &ast::Decorator) -> Type {
let ast::Decorator {
range: _,
expression,
} = decorator;
self.infer_expression(expression)
}
fn infer_arguments(&mut self, arguments: &ast::Arguments) -> Vec<Type> {
let mut types = Vec::with_capacity(
arguments
.args
.len()
.saturating_add(arguments.keywords.len()),
);
types.extend(arguments.args.iter().map(|arg| self.infer_expression(arg)));
types.extend(arguments.keywords.iter().map(
|ast::Keyword {
range: _,
arg: _,
value,
}| self.infer_expression(value),
));
types
}
fn infer_expression(&mut self, expression: &ast::Expr) -> Type {
let ty = match expression {
ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _ }) => Type::None,
ast::Expr::NumberLiteral(literal) => self.infer_number_literal_expression(literal),
ast::Expr::Name(name) => self.infer_name_expression(name),
ast::Expr::Attribute(attribute) => self.infer_attribute_expression(attribute),
ast::Expr::BinOp(binary) => self.infer_binary_expression(binary),
ast::Expr::Named(named) => self.infer_named_expression(named),
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression),
_ => todo!("expression type resolution for {:?}", expression),
};
self.types.expression_tys.push(ty);
ty
}
#[allow(clippy::unused_self)]
fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type {
let ast::ExprNumberLiteral { range: _, value } = literal;
match value {
ast::Number::Int(n) => {
// TODO support big int literals
n.as_i64().map(Type::IntLiteral).unwrap_or(Type::Unknown)
}
// TODO builtins.float or builtins.complex
_ => Type::Unknown,
}
}
fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type {
let ast::ExprNamed {
range: _,
target,
value,
} = named;
let value_ty = self.infer_expression(value);
self.infer_expression(target);
self.definition_tys.insert(
Definition::NamedExpr(named.scope_ast_id(self.db, self.file_id, self.file_scope_id)),
value_ty,
);
value_ty
}
fn infer_if_expression(&mut self, if_expression: &ast::ExprIf) -> Type {
let ast::ExprIf {
range: _,
test,
body,
orelse,
} = if_expression;
self.infer_expression(test);
// TODO detect statically known truthy or falsy test
let body_ty = self.infer_expression(body);
let orelse_ty = self.infer_expression(orelse);
let union = UnionTypeBuilder::new(&self.typing_context())
.add(body_ty)
.add(orelse_ty)
.build();
self.union_ty(union)
}
fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type {
let ast::ExprName { range: _, id, ctx } = name;
match ctx {
ExprContext::Load => {
if let Some(symbol_id) = self
.index
.symbol_table(self.file_scope_id)
.symbol_id_by_name(id)
{
self.local_definition_ty(symbol_id)
} else {
let ancestors = self.index.ancestor_scopes(self.file_scope_id).skip(1);
for (ancestor_id, _) in ancestors {
// TODO: Skip over class scopes unless the they are a immediately-nested type param scope.
// TODO: Support built-ins
let symbol_table =
symbol_table(self.db, ancestor_id.to_scope_id(self.db, self.file_id));
if let Some(_symbol_id) = symbol_table.symbol_id_by_name(id) {
todo!("Return type for symbol from outer scope");
}
}
Type::Unknown
}
}
ExprContext::Del => Type::None,
ExprContext::Invalid => Type::Unknown,
ExprContext::Store => Type::None,
}
}
fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type {
let ast::ExprAttribute {
value,
attr,
range: _,
ctx,
} = attribute;
let value_ty = self.infer_expression(value);
let member_ty = value_ty
.member(&self.typing_context(), &Name::new(&attr.id))
.unwrap_or(Type::Unknown);
match ctx {
ExprContext::Load => member_ty,
ExprContext::Store | ExprContext::Del => Type::None,
ExprContext::Invalid => Type::Unknown,
}
}
fn infer_binary_expression(&mut self, binary: &ast::ExprBinOp) -> Type {
let ast::ExprBinOp {
left,
op,
right,
range: _,
} = binary;
let left_ty = self.infer_expression(left);
let right_ty = self.infer_expression(right);
match left_ty {
Type::Any => Type::Any,
Type::Unknown => Type::Unknown,
Type::IntLiteral(n) => {
match right_ty {
Type::IntLiteral(m) => {
match op {
ast::Operator::Add => n
.checked_add(m)
.map(Type::IntLiteral)
// TODO builtins.int
.unwrap_or(Type::Unknown),
ast::Operator::Sub => n
.checked_sub(m)
.map(Type::IntLiteral)
// TODO builtins.int
.unwrap_or(Type::Unknown),
ast::Operator::Mult => n
.checked_mul(m)
.map(Type::IntLiteral)
// TODO builtins.int
.unwrap_or(Type::Unknown),
ast::Operator::Div => n
.checked_div(m)
.map(Type::IntLiteral)
// TODO builtins.int
.unwrap_or(Type::Unknown),
ast::Operator::Mod => n
.checked_rem(m)
.map(Type::IntLiteral)
// TODO division by zero error
.unwrap_or(Type::Unknown),
_ => todo!("complete binop op support for IntLiteral"),
}
}
_ => todo!("complete binop right_ty support for IntLiteral"),
}
}
_ => todo!("complete binop support"),
}
}
fn infer_type_parameters(&mut self, _type_parameters: &TypeParams) {
todo!("Infer type parameters")
}
pub(super) fn finish(mut self) -> TypeInference {
let symbol_tys: IndexVec<_, _> = self
.index
.symbol_table(self.file_scope_id)
.symbol_ids()
.map(|symbol| self.local_definition_ty(symbol))
.collect();
self.types.symbol_tys = symbol_tys;
self.types.shrink_to_fit();
self.types
}
fn union_ty(&mut self, ty: UnionType) -> Type {
Type::Union(TypeId {
scope: self.scope,
scoped: self.types.union_types.push(ty),
})
}
fn function_ty(&mut self, ty: FunctionType) -> Type {
Type::Function(TypeId {
scope: self.scope,
scoped: self.types.function_types.push(ty),
})
}
fn class_ty(&mut self, ty: ClassType) -> Type {
Type::Class(TypeId {
scope: self.scope,
scoped: self.types.class_types.push(ty),
})
}
fn typing_context(&self) -> TypingContext {
TypingContext::scoped(self.db, self.scope, &self.types)
}
fn local_definition_ty(&mut self, symbol: ScopedSymbolId) -> Type {
let symbol = self.symbol_table.symbol(symbol);
let mut definitions = symbol
.definitions()
.iter()
.filter_map(|definition| self.definition_tys.get(definition).copied());
let Some(first) = definitions.next() else {
return Type::Unbound;
};
if let Some(second) = definitions.next() {
let context = self.typing_context();
let mut builder = UnionTypeBuilder::new(&context);
builder = builder.add(first).add(second);
for variant in definitions {
builder = builder.add(variant);
}
self.union_ty(builder.build())
} else {
first
}
}
}
#[cfg(test)]
mod tests {
use ruff_db::file_system::FileSystemPathBuf;
use ruff_db::vfs::system_path_to_file;
use crate::db::tests::TestDb;
use crate::module::resolver::{set_module_resolution_settings, ModuleResolutionSettings};
use crate::name::Name;
use crate::types::{public_symbol_ty_by_name, Type, TypingContext};
fn setup_db() -> TestDb {
let mut db = TestDb::new();
set_module_resolution_settings(
&mut db,
ModuleResolutionSettings {
extra_paths: Vec::new(),
workspace_root: FileSystemPathBuf::from("/src"),
site_packages: None,
custom_typeshed: None,
},
);
db
}
fn assert_public_ty(db: &TestDb, file_name: &str, symbol_name: &str, expected: &str) {
let file = system_path_to_file(db, file_name).expect("Expected file to exist.");
let ty = public_symbol_ty_by_name(db, file, symbol_name).unwrap_or(Type::Unknown);
assert_eq!(ty.display(&TypingContext::global(db)).to_string(), expected);
}
#[test]
fn follow_import_to_class() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_files([
("src/a.py", "from b import C as D; E = D"),
("src/b.py", "class C: pass"),
])?;
assert_public_ty(&db, "src/a.py", "E", "Literal[C]");
Ok(())
}
#[test]
fn resolve_base_class_by_name() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file(
"src/mod.py",
r#"
class Base:
pass
class Sub(Base):
pass"#,
)?;
let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist.");
let ty = public_symbol_ty_by_name(&db, mod_file, "Sub").expect("Symbol type to exist");
let Type::Class(class_id) = ty else {
panic!("Sub is not a Class")
};
let context = TypingContext::global(&db);
let base_names: Vec<_> = class_id
.lookup(&context)
.bases()
.iter()
.map(|base_ty| format!("{}", base_ty.display(&context)))
.collect();
assert_eq!(base_names, vec!["Literal[Base]"]);
Ok(())
}
#[test]
fn resolve_method() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file(
"src/mod.py",
"
class C:
def f(self): pass
",
)?;
let mod_file = system_path_to_file(&db, "src/mod.py").unwrap();
let ty = public_symbol_ty_by_name(&db, mod_file, "C").unwrap();
let Type::Class(class_id) = ty else {
panic!("C is not a Class");
};
let context = TypingContext::global(&db);
let member_ty = class_id.class_member(&context, &Name::new("f"));
let Some(Type::Function(func_id)) = member_ty else {
panic!("C.f is not a Function");
};
let function_ty = func_id.lookup(&context);
assert_eq!(function_ty.name(), "f");
Ok(())
}
#[test]
fn resolve_module_member() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_files([
("src/a.py", "import b; D = b.C"),
("src/b.py", "class C: pass"),
])?;
assert_public_ty(&db, "src/a.py", "D", "Literal[C]");
Ok(())
}
#[test]
fn resolve_literal() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file("src/a.py", "x = 1")?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1]");
Ok(())
}
#[test]
fn resolve_union() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file(
"src/a.py",
"
if flag:
x = 1
else:
x = 2
",
)?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");
Ok(())
}
#[test]
fn literal_int_arithmetic() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file(
"src/a.py",
"
a = 2 + 1
b = a - 4
c = a * b
d = c / 3
e = 5 % 3
",
)?;
assert_public_ty(&db, "src/a.py", "a", "Literal[3]");
assert_public_ty(&db, "src/a.py", "b", "Literal[-1]");
assert_public_ty(&db, "src/a.py", "c", "Literal[-3]");
assert_public_ty(&db, "src/a.py", "d", "Literal[-1]");
assert_public_ty(&db, "src/a.py", "e", "Literal[2]");
Ok(())
}
#[test]
fn walrus() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system()
.write_file("src/a.py", "x = (y := 1) + 1")?;
assert_public_ty(&db, "src/a.py", "x", "Literal[2]");
assert_public_ty(&db, "src/a.py", "y", "Literal[1]");
Ok(())
}
#[test]
fn ifexpr() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system()
.write_file("src/a.py", "x = 1 if flag else 2")?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");
Ok(())
}
#[test]
fn ifexpr_walrus() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file(
"src/a.py",
"
y = z = 0
x = (y := 1) if flag else (z := 2)
a = y
b = z
",
)?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");
assert_public_ty(&db, "src/a.py", "a", "Literal[0, 1]");
assert_public_ty(&db, "src/a.py", "b", "Literal[0, 2]");
Ok(())
}
#[test]
fn ifexpr_nested() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system()
.write_file("src/a.py", "x = 1 if flag else 2 if flag2 else 3")?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2, 3]");
Ok(())
}
#[test]
fn none() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system()
.write_file("src/a.py", "x = 1 if flag else None")?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1] | None");
Ok(())
}
}