mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-04 10:49:50 +00:00
red-knot(Salsa): Types without refinements (#11899)
This commit is contained in:
parent
a26bd01be2
commit
22733cb7c7
13 changed files with 2169 additions and 147 deletions
|
@ -20,6 +20,7 @@ ruff_text_size = { workspace = true }
|
|||
|
||||
bitflags = { workspace = true }
|
||||
is-macro = { workspace = true }
|
||||
indexmap = { workspace = true, optional = true }
|
||||
salsa = { workspace = true, optional = true }
|
||||
smallvec = { workspace = true, optional = true }
|
||||
smol_str = { workspace = true }
|
||||
|
@ -36,4 +37,4 @@ tempfile = { workspace = true }
|
|||
workspace = true
|
||||
|
||||
[features]
|
||||
red_knot = ["dep:salsa", "dep:tracing", "dep:hashbrown", "dep:smallvec"]
|
||||
red_knot = ["dep:salsa", "dep:tracing", "dep:hashbrown", "dep:smallvec", "dep:indexmap"]
|
||||
|
|
|
@ -6,20 +6,27 @@ use crate::module::resolver::{
|
|||
file_to_module, internal::ModuleNameIngredient, internal::ModuleResolverSearchPaths,
|
||||
resolve_module_query,
|
||||
};
|
||||
|
||||
use crate::red_knot::semantic_index::symbol::ScopeId;
|
||||
use crate::red_knot::semantic_index::{scopes_map, semantic_index, symbol_table};
|
||||
use crate::red_knot::semantic_index::symbol::{
|
||||
public_symbols_map, scopes_map, PublicSymbolId, ScopeId,
|
||||
};
|
||||
use crate::red_knot::semantic_index::{root_scope, semantic_index, symbol_table};
|
||||
use crate::red_knot::types::{infer_types, public_symbol_ty};
|
||||
|
||||
#[salsa::jar(db=Db)]
|
||||
pub struct Jar(
|
||||
ModuleNameIngredient,
|
||||
ModuleResolverSearchPaths,
|
||||
ScopeId,
|
||||
PublicSymbolId,
|
||||
symbol_table,
|
||||
resolve_module_query,
|
||||
file_to_module,
|
||||
scopes_map,
|
||||
root_scope,
|
||||
semantic_index,
|
||||
infer_types,
|
||||
public_symbol_ty,
|
||||
public_symbols_map,
|
||||
);
|
||||
|
||||
/// Database giving access to semantic information about a Python program.
|
||||
|
@ -27,9 +34,13 @@ pub trait Db: SourceDb + DbWithJar<Jar> + Upcast<dyn SourceDb> {}
|
|||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use std::fmt::Formatter;
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use salsa::DebugWithDb;
|
||||
use salsa::ingredient::Ingredient;
|
||||
use salsa::storage::HasIngredientsFor;
|
||||
use salsa::{AsId, DebugWithDb};
|
||||
|
||||
use ruff_db::file_system::{FileSystem, MemoryFileSystem, OsFileSystem};
|
||||
use ruff_db::vfs::Vfs;
|
||||
|
@ -86,7 +97,7 @@ pub(crate) mod tests {
|
|||
///
|
||||
/// ## Panics
|
||||
/// If there are any pending salsa snapshots.
|
||||
pub(crate) fn take_sale_events(&mut self) -> Vec<salsa::Event> {
|
||||
pub(crate) fn take_salsa_events(&mut self) -> Vec<salsa::Event> {
|
||||
let inner = Arc::get_mut(&mut self.events).expect("no pending salsa snapshots");
|
||||
|
||||
let events = inner.get_mut().unwrap();
|
||||
|
@ -98,7 +109,7 @@ pub(crate) mod tests {
|
|||
/// ## Panics
|
||||
/// If there are any pending salsa snapshots.
|
||||
pub(crate) fn clear_salsa_events(&mut self) {
|
||||
self.take_sale_events();
|
||||
self.take_salsa_events();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -150,4 +161,106 @@ pub(crate) mod tests {
|
|||
#[allow(unused)]
|
||||
Os(OsFileSystem),
|
||||
}
|
||||
|
||||
pub(crate) fn assert_will_run_function_query<C, Db, Jar>(
|
||||
db: &Db,
|
||||
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
|
||||
key: C::Key,
|
||||
events: &[salsa::Event],
|
||||
) where
|
||||
C: salsa::function::Configuration<Jar = Jar>
|
||||
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
|
||||
Jar: HasIngredientsFor<C>,
|
||||
Db: salsa::DbWithJar<Jar>,
|
||||
C::Key: AsId,
|
||||
{
|
||||
will_run_function_query(db, to_function, key, events, true);
|
||||
}
|
||||
|
||||
pub(crate) fn assert_will_not_run_function_query<C, Db, Jar>(
|
||||
db: &Db,
|
||||
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
|
||||
key: C::Key,
|
||||
events: &[salsa::Event],
|
||||
) where
|
||||
C: salsa::function::Configuration<Jar = Jar>
|
||||
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
|
||||
Jar: HasIngredientsFor<C>,
|
||||
Db: salsa::DbWithJar<Jar>,
|
||||
C::Key: AsId,
|
||||
{
|
||||
will_run_function_query(db, to_function, key, events, false);
|
||||
}
|
||||
|
||||
fn will_run_function_query<C, Db, Jar>(
|
||||
db: &Db,
|
||||
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
|
||||
key: C::Key,
|
||||
events: &[salsa::Event],
|
||||
should_run: bool,
|
||||
) where
|
||||
C: salsa::function::Configuration<Jar = Jar>
|
||||
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
|
||||
Jar: HasIngredientsFor<C>,
|
||||
Db: salsa::DbWithJar<Jar>,
|
||||
C::Key: AsId,
|
||||
{
|
||||
let (jar, _) =
|
||||
<_ as salsa::storage::HasJar<<C as salsa::storage::IngredientsFor>::Jar>>::jar(db);
|
||||
let ingredient = jar.ingredient();
|
||||
|
||||
let function_ingredient = to_function(ingredient);
|
||||
|
||||
let ingredient_index =
|
||||
<salsa::function::FunctionIngredient<C> as Ingredient<Db>>::ingredient_index(
|
||||
function_ingredient,
|
||||
);
|
||||
|
||||
let did_run = events.iter().any(|event| {
|
||||
if let salsa::EventKind::WillExecute { database_key } = event.kind {
|
||||
database_key.ingredient_index() == ingredient_index
|
||||
&& database_key.key_index() == key.as_id()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
|
||||
if should_run && !did_run {
|
||||
panic!(
|
||||
"Expected query {:?} to run but it didn't",
|
||||
DebugIdx {
|
||||
db: PhantomData::<Db>,
|
||||
value_id: key.as_id(),
|
||||
ingredient: function_ingredient,
|
||||
}
|
||||
);
|
||||
} else if !should_run && did_run {
|
||||
panic!(
|
||||
"Expected query {:?} not to run but it did",
|
||||
DebugIdx {
|
||||
db: PhantomData::<Db>,
|
||||
value_id: key.as_id(),
|
||||
ingredient: function_ingredient,
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
struct DebugIdx<'a, I, Db>
|
||||
where
|
||||
I: Ingredient<Db>,
|
||||
{
|
||||
value_id: salsa::Id,
|
||||
ingredient: &'a I,
|
||||
db: PhantomData<Db>,
|
||||
}
|
||||
|
||||
impl<'a, I, Db> std::fmt::Debug for DebugIdx<'a, I, Db>
|
||||
where
|
||||
I: Ingredient<Db>,
|
||||
{
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
self.ingredient.fmt_index(Some(self.value_id), f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -886,7 +886,7 @@ mod tests {
|
|||
let foo_module2 = resolve_module(&db, foo_module_name);
|
||||
|
||||
assert!(!db
|
||||
.take_sale_events()
|
||||
.take_salsa_events()
|
||||
.iter()
|
||||
.any(|event| { matches!(event.kind, salsa::EventKind::WillExecute { .. }) }));
|
||||
|
||||
|
|
|
@ -1,3 +1,8 @@
|
|||
use rustc_hash::FxHasher;
|
||||
use std::hash::BuildHasherDefault;
|
||||
|
||||
pub mod ast_node_ref;
|
||||
mod node_key;
|
||||
pub mod semantic_index;
|
||||
pub mod types;
|
||||
pub(crate) type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;
|
||||
|
|
|
@ -9,10 +9,10 @@ use ruff_index::{IndexSlice, IndexVec};
|
|||
use ruff_python_ast as ast;
|
||||
|
||||
use crate::red_knot::node_key::NodeKey;
|
||||
use crate::red_knot::semantic_index::ast_ids::AstIds;
|
||||
use crate::red_knot::semantic_index::ast_ids::{AstId, AstIds, ScopeClassId, ScopeFunctionId};
|
||||
use crate::red_knot::semantic_index::builder::SemanticIndexBuilder;
|
||||
use crate::red_knot::semantic_index::symbol::{
|
||||
FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeSymbolId, ScopesMap, SymbolTable,
|
||||
FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable,
|
||||
};
|
||||
use crate::Db;
|
||||
|
||||
|
@ -21,7 +21,7 @@ mod builder;
|
|||
pub mod definition;
|
||||
pub mod symbol;
|
||||
|
||||
type SymbolMap = hashbrown::HashMap<ScopeSymbolId, (), ()>;
|
||||
type SymbolMap = hashbrown::HashMap<ScopedSymbolId, (), ()>;
|
||||
|
||||
/// Returns the semantic index for `file`.
|
||||
///
|
||||
|
@ -42,33 +42,22 @@ pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex {
|
|||
pub(crate) fn symbol_table(db: &dyn Db, scope: ScopeId) -> Arc<SymbolTable> {
|
||||
let index = semantic_index(db, scope.file(db));
|
||||
|
||||
index.symbol_table(scope.scope_id(db))
|
||||
}
|
||||
|
||||
/// Returns a mapping from file specific [`FileScopeId`] to a program-wide unique [`ScopeId`].
|
||||
#[salsa::tracked(return_ref)]
|
||||
pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap {
|
||||
let index = semantic_index(db, file);
|
||||
|
||||
let scopes: IndexVec<_, _> = index
|
||||
.scopes
|
||||
.indices()
|
||||
.map(|id| ScopeId::new(db, file, id))
|
||||
.collect();
|
||||
|
||||
ScopesMap::new(scopes)
|
||||
index.symbol_table(scope.file_scope_id(db))
|
||||
}
|
||||
|
||||
/// Returns the root scope of `file`.
|
||||
pub fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId {
|
||||
#[salsa::tracked]
|
||||
pub(crate) fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId {
|
||||
FileScopeId::root().to_scope_id(db, file)
|
||||
}
|
||||
|
||||
/// Returns the symbol with the given name in `file`'s public scope or `None` if
|
||||
/// no symbol with the given name exists.
|
||||
pub fn global_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option<PublicSymbolId> {
|
||||
pub fn public_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option<PublicSymbolId> {
|
||||
let root_scope = root_scope(db, file);
|
||||
root_scope.symbol(db, name)
|
||||
let symbol_table = symbol_table(db, root_scope);
|
||||
let local = symbol_table.symbol_id_by_name(name)?;
|
||||
Some(local.to_public_symbol(db, file))
|
||||
}
|
||||
|
||||
/// The symbol tables for an entire file.
|
||||
|
@ -90,6 +79,9 @@ pub struct SemanticIndex {
|
|||
/// Note: We should not depend on this map when analysing other files or
|
||||
/// changing a file invalidates all dependents.
|
||||
ast_ids: IndexVec<FileScopeId, AstIds>,
|
||||
|
||||
/// Map from scope to the node that introduces the scope.
|
||||
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>,
|
||||
}
|
||||
|
||||
impl SemanticIndex {
|
||||
|
@ -97,7 +89,7 @@ impl SemanticIndex {
|
|||
///
|
||||
/// Use the Salsa cached [`symbol_table`] query if you only need the
|
||||
/// symbol table for a single scope.
|
||||
fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
|
||||
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
|
||||
self.symbol_tables[scope_id].clone()
|
||||
}
|
||||
|
||||
|
@ -152,6 +144,10 @@ impl SemanticIndex {
|
|||
pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter {
|
||||
AncestorsIter::new(self, scope)
|
||||
}
|
||||
|
||||
pub(crate) fn scope_node(&self, scope_id: FileScopeId) -> NodeWithScopeId {
|
||||
self.scope_nodes[scope_id]
|
||||
}
|
||||
}
|
||||
|
||||
/// ID that uniquely identifies an expression inside a [`Scope`].
|
||||
|
@ -246,6 +242,28 @@ impl<'a> Iterator for ChildrenIter<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub(crate) enum NodeWithScopeId {
|
||||
Module,
|
||||
Class(AstId<ScopeClassId>),
|
||||
ClassTypeParams(AstId<ScopeClassId>),
|
||||
Function(AstId<ScopeFunctionId>),
|
||||
FunctionTypeParams(AstId<ScopeFunctionId>),
|
||||
}
|
||||
|
||||
impl NodeWithScopeId {
|
||||
fn scope_kind(self) -> ScopeKind {
|
||||
match self {
|
||||
NodeWithScopeId::Module => ScopeKind::Module,
|
||||
NodeWithScopeId::Class(_) => ScopeKind::Class,
|
||||
NodeWithScopeId::Function(_) => ScopeKind::Function,
|
||||
NodeWithScopeId::ClassTypeParams(_) | NodeWithScopeId::FunctionTypeParams(_) => {
|
||||
ScopeKind::Annotation
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FusedIterator for ChildrenIter<'_> {}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -583,19 +601,14 @@ class C[T]:
|
|||
let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4");
|
||||
|
||||
let index = semantic_index(&db, file);
|
||||
let root_table = index.symbol_table(FileScopeId::root());
|
||||
let parsed = parsed_module(&db, file);
|
||||
let ast = parsed.syntax();
|
||||
|
||||
let x_sym = root_table
|
||||
.symbol_by_name("x")
|
||||
.expect("x symbol should exist");
|
||||
|
||||
let x_stmt = ast.body[0].as_assign_stmt().unwrap();
|
||||
let x = &x_stmt.targets[0];
|
||||
|
||||
assert_eq!(index.expression_scope(x).kind(), ScopeKind::Module);
|
||||
assert_eq!(index.expression_scope_id(x), x_sym.scope());
|
||||
assert_eq!(index.expression_scope_id(x), FileScopeId::root());
|
||||
|
||||
let def = ast.body[1].as_function_def_stmt().unwrap();
|
||||
let y_stmt = def.body[0].as_assign_stmt().unwrap();
|
||||
|
|
|
@ -66,13 +66,13 @@ impl std::fmt::Debug for AstIds {
|
|||
}
|
||||
|
||||
fn ast_ids(db: &dyn Db, scope: ScopeId) -> &AstIds {
|
||||
semantic_index(db, scope.file(db)).ast_ids(scope.scope_id(db))
|
||||
semantic_index(db, scope.file(db)).ast_ids(scope.file_scope_id(db))
|
||||
}
|
||||
|
||||
/// Node that can be uniquely identified by an id in a [`FileScopeId`].
|
||||
pub trait ScopeAstIdNode {
|
||||
/// The type of the ID uniquely identifying the node.
|
||||
type Id;
|
||||
type Id: Copy;
|
||||
|
||||
/// Returns the ID that uniquely identifies the node in `scope`.
|
||||
///
|
||||
|
@ -91,7 +91,7 @@ pub trait ScopeAstIdNode {
|
|||
|
||||
/// Extension trait for AST nodes that can be resolved by an `AstId`.
|
||||
pub trait AstIdNode {
|
||||
type ScopeId;
|
||||
type ScopeId: Copy;
|
||||
|
||||
/// Resolves the AST id of the node.
|
||||
///
|
||||
|
@ -133,7 +133,7 @@ where
|
|||
|
||||
/// Uniquely identifies an AST node in a file.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct AstId<L> {
|
||||
pub struct AstId<L: Copy> {
|
||||
/// The node's scope.
|
||||
scope: FileScopeId,
|
||||
|
||||
|
@ -141,6 +141,16 @@ pub struct AstId<L> {
|
|||
in_scope_id: L,
|
||||
}
|
||||
|
||||
impl<L: Copy> AstId<L> {
|
||||
pub(super) fn new(scope: FileScopeId, in_scope_id: L) -> Self {
|
||||
Self { scope, in_scope_id }
|
||||
}
|
||||
|
||||
pub(super) fn in_scope_id(self) -> L {
|
||||
self.in_scope_id
|
||||
}
|
||||
}
|
||||
|
||||
/// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`].
|
||||
#[newtype_index]
|
||||
pub struct ScopeExpressionId;
|
||||
|
|
|
@ -10,16 +10,16 @@ use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor};
|
|||
use crate::name::Name;
|
||||
use crate::red_knot::node_key::NodeKey;
|
||||
use crate::red_knot::semantic_index::ast_ids::{
|
||||
AstIdsBuilder, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, ScopeImportFromId,
|
||||
AstId, AstIdsBuilder, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, ScopeImportFromId,
|
||||
ScopeImportId, ScopeNamedExprId,
|
||||
};
|
||||
use crate::red_knot::semantic_index::definition::{
|
||||
Definition, ImportDefinition, ImportFromDefinition,
|
||||
};
|
||||
use crate::red_knot::semantic_index::symbol::{
|
||||
FileScopeId, FileSymbolId, Scope, ScopeKind, ScopeSymbolId, SymbolFlags, SymbolTableBuilder,
|
||||
FileScopeId, FileSymbolId, Scope, ScopedSymbolId, SymbolFlags, SymbolTableBuilder,
|
||||
};
|
||||
use crate::red_knot::semantic_index::SemanticIndex;
|
||||
use crate::red_knot::semantic_index::{NodeWithScopeId, SemanticIndex};
|
||||
|
||||
pub(super) struct SemanticIndexBuilder<'a> {
|
||||
// Builder state
|
||||
|
@ -33,6 +33,7 @@ pub(super) struct SemanticIndexBuilder<'a> {
|
|||
symbol_tables: IndexVec<FileScopeId, SymbolTableBuilder>,
|
||||
ast_ids: IndexVec<FileScopeId, AstIdsBuilder>,
|
||||
expression_scopes: FxHashMap<NodeKey, FileScopeId>,
|
||||
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>,
|
||||
}
|
||||
|
||||
impl<'a> SemanticIndexBuilder<'a> {
|
||||
|
@ -46,10 +47,11 @@ impl<'a> SemanticIndexBuilder<'a> {
|
|||
symbol_tables: IndexVec::new(),
|
||||
ast_ids: IndexVec::new(),
|
||||
expression_scopes: FxHashMap::default(),
|
||||
scope_nodes: IndexVec::new(),
|
||||
};
|
||||
|
||||
builder.push_scope_with_parent(
|
||||
ScopeKind::Module,
|
||||
NodeWithScopeId::Module,
|
||||
&Name::new_static("<module>"),
|
||||
None,
|
||||
None,
|
||||
|
@ -68,18 +70,18 @@ impl<'a> SemanticIndexBuilder<'a> {
|
|||
|
||||
fn push_scope(
|
||||
&mut self,
|
||||
scope_kind: ScopeKind,
|
||||
node: NodeWithScopeId,
|
||||
name: &Name,
|
||||
defining_symbol: Option<FileSymbolId>,
|
||||
definition: Option<Definition>,
|
||||
) {
|
||||
let parent = self.current_scope();
|
||||
self.push_scope_with_parent(scope_kind, name, defining_symbol, definition, Some(parent));
|
||||
self.push_scope_with_parent(node, name, defining_symbol, definition, Some(parent));
|
||||
}
|
||||
|
||||
fn push_scope_with_parent(
|
||||
&mut self,
|
||||
scope_kind: ScopeKind,
|
||||
node: NodeWithScopeId,
|
||||
name: &Name,
|
||||
defining_symbol: Option<FileSymbolId>,
|
||||
definition: Option<Definition>,
|
||||
|
@ -92,13 +94,17 @@ impl<'a> SemanticIndexBuilder<'a> {
|
|||
parent,
|
||||
defining_symbol,
|
||||
definition,
|
||||
kind: scope_kind,
|
||||
kind: node.scope_kind(),
|
||||
descendents: children_start..children_start,
|
||||
};
|
||||
|
||||
let scope_id = self.scopes.push(scope);
|
||||
self.symbol_tables.push(SymbolTableBuilder::new());
|
||||
self.ast_ids.push(AstIdsBuilder::new());
|
||||
let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new());
|
||||
let scope_node_id = self.scope_nodes.push(node);
|
||||
|
||||
debug_assert_eq!(ast_id_scope, scope_id);
|
||||
debug_assert_eq!(scope_id, scope_node_id);
|
||||
self.scope_stack.push(scope_id);
|
||||
}
|
||||
|
||||
|
@ -120,11 +126,10 @@ impl<'a> SemanticIndexBuilder<'a> {
|
|||
&mut self.ast_ids[scope_id]
|
||||
}
|
||||
|
||||
fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopeSymbolId {
|
||||
let scope = self.current_scope();
|
||||
fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId {
|
||||
let symbol_table = self.current_symbol_table();
|
||||
|
||||
symbol_table.add_or_update_symbol(name, scope, flags, None)
|
||||
symbol_table.add_or_update_symbol(name, flags, None)
|
||||
}
|
||||
|
||||
fn add_or_update_symbol_with_definition(
|
||||
|
@ -132,27 +137,32 @@ impl<'a> SemanticIndexBuilder<'a> {
|
|||
name: Name,
|
||||
|
||||
definition: Definition,
|
||||
) -> ScopeSymbolId {
|
||||
let scope = self.current_scope();
|
||||
) -> ScopedSymbolId {
|
||||
let symbol_table = self.current_symbol_table();
|
||||
|
||||
symbol_table.add_or_update_symbol(name, scope, SymbolFlags::IS_DEFINED, Some(definition))
|
||||
symbol_table.add_or_update_symbol(name, SymbolFlags::IS_DEFINED, Some(definition))
|
||||
}
|
||||
|
||||
fn with_type_params(
|
||||
&mut self,
|
||||
name: &Name,
|
||||
params: &Option<Box<ast::TypeParams>>,
|
||||
definition: Option<Definition>,
|
||||
with_params: &WithTypeParams,
|
||||
defining_symbol: FileSymbolId,
|
||||
nested: impl FnOnce(&mut Self) -> FileScopeId,
|
||||
) -> FileScopeId {
|
||||
if let Some(type_params) = params {
|
||||
let type_params = with_params.type_parameters();
|
||||
|
||||
if let Some(type_params) = type_params {
|
||||
let type_node = match with_params {
|
||||
WithTypeParams::ClassDef { id, .. } => NodeWithScopeId::ClassTypeParams(*id),
|
||||
WithTypeParams::FunctionDef { id, .. } => NodeWithScopeId::FunctionTypeParams(*id),
|
||||
};
|
||||
|
||||
self.push_scope(
|
||||
ScopeKind::Annotation,
|
||||
type_node,
|
||||
name,
|
||||
Some(defining_symbol),
|
||||
definition,
|
||||
Some(with_params.definition()),
|
||||
);
|
||||
for type_param in &type_params.type_params {
|
||||
let name = match type_param {
|
||||
|
@ -163,9 +173,10 @@ impl<'a> SemanticIndexBuilder<'a> {
|
|||
self.add_or_update_symbol(Name::new(name), SymbolFlags::IS_DEFINED);
|
||||
}
|
||||
}
|
||||
|
||||
let nested_scope = nested(self);
|
||||
|
||||
if params.is_some() {
|
||||
if type_params.is_some() {
|
||||
self.pop_scope();
|
||||
}
|
||||
|
||||
|
@ -198,10 +209,12 @@ impl<'a> SemanticIndexBuilder<'a> {
|
|||
ast_ids.shrink_to_fit();
|
||||
symbol_tables.shrink_to_fit();
|
||||
self.expression_scopes.shrink_to_fit();
|
||||
self.scope_nodes.shrink_to_fit();
|
||||
|
||||
SemanticIndex {
|
||||
symbol_tables,
|
||||
scopes: self.scopes,
|
||||
scope_nodes: self.scope_nodes,
|
||||
ast_ids,
|
||||
expression_scopes: self.expression_scopes,
|
||||
}
|
||||
|
@ -223,7 +236,8 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
|
|||
self.visit_decorator(decorator);
|
||||
}
|
||||
let name = Name::new(&function_def.name.id);
|
||||
let definition = Definition::FunctionDef(ScopeFunctionId(statement_id));
|
||||
let function_id = ScopeFunctionId(statement_id);
|
||||
let definition = Definition::FunctionDef(function_id);
|
||||
let scope = self.current_scope();
|
||||
let symbol = FileSymbolId::new(
|
||||
scope,
|
||||
|
@ -232,8 +246,10 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
|
|||
|
||||
self.with_type_params(
|
||||
&name,
|
||||
&function_def.type_params,
|
||||
Some(definition),
|
||||
&WithTypeParams::FunctionDef {
|
||||
node: function_def,
|
||||
id: AstId::new(scope, function_id),
|
||||
},
|
||||
symbol,
|
||||
|builder| {
|
||||
builder.visit_parameters(&function_def.parameters);
|
||||
|
@ -242,7 +258,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
|
|||
}
|
||||
|
||||
builder.push_scope(
|
||||
ScopeKind::Function,
|
||||
NodeWithScopeId::Function(AstId::new(scope, function_id)),
|
||||
&name,
|
||||
Some(symbol),
|
||||
Some(definition),
|
||||
|
@ -258,21 +274,36 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
|
|||
}
|
||||
|
||||
let name = Name::new(&class.name.id);
|
||||
let definition = Definition::from(ScopeClassId(statement_id));
|
||||
let class_id = ScopeClassId(statement_id);
|
||||
let definition = Definition::from(class_id);
|
||||
let scope = self.current_scope();
|
||||
let id = FileSymbolId::new(
|
||||
self.current_scope(),
|
||||
self.add_or_update_symbol_with_definition(name.clone(), definition),
|
||||
);
|
||||
self.with_type_params(&name, &class.type_params, Some(definition), id, |builder| {
|
||||
if let Some(arguments) = &class.arguments {
|
||||
builder.visit_arguments(arguments);
|
||||
}
|
||||
self.with_type_params(
|
||||
&name,
|
||||
&WithTypeParams::ClassDef {
|
||||
node: class,
|
||||
id: AstId::new(scope, class_id),
|
||||
},
|
||||
id,
|
||||
|builder| {
|
||||
if let Some(arguments) = &class.arguments {
|
||||
builder.visit_arguments(arguments);
|
||||
}
|
||||
|
||||
builder.push_scope(ScopeKind::Class, &name, Some(id), Some(definition));
|
||||
builder.visit_body(&class.body);
|
||||
builder.push_scope(
|
||||
NodeWithScopeId::Class(AstId::new(scope, class_id)),
|
||||
&name,
|
||||
Some(id),
|
||||
Some(definition),
|
||||
);
|
||||
builder.visit_body(&class.body);
|
||||
|
||||
builder.pop_scope()
|
||||
});
|
||||
builder.pop_scope()
|
||||
},
|
||||
);
|
||||
}
|
||||
ast::Stmt::Import(ast::StmtImport { names, .. }) => {
|
||||
for (i, alias) in names.iter().enumerate() {
|
||||
|
@ -396,3 +427,30 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum WithTypeParams<'a> {
|
||||
ClassDef {
|
||||
node: &'a ast::StmtClassDef,
|
||||
id: AstId<ScopeClassId>,
|
||||
},
|
||||
FunctionDef {
|
||||
node: &'a ast::StmtFunctionDef,
|
||||
id: AstId<ScopeFunctionId>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<'a> WithTypeParams<'a> {
|
||||
fn type_parameters(&self) -> Option<&'a ast::TypeParams> {
|
||||
match self {
|
||||
WithTypeParams::ClassDef { node, .. } => node.type_params.as_deref(),
|
||||
WithTypeParams::FunctionDef { node, .. } => node.type_params.as_deref(),
|
||||
}
|
||||
}
|
||||
|
||||
fn definition(&self) -> Definition {
|
||||
match self {
|
||||
WithTypeParams::ClassDef { id, .. } => Definition::ClassDef(id.in_scope_id()),
|
||||
WithTypeParams::FunctionDef { id, .. } => Definition::FunctionDef(id.in_scope_id()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ use crate::red_knot::semantic_index::ast_ids::{
|
|||
ScopeImportFromId, ScopeImportId, ScopeNamedExprId,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum Definition {
|
||||
Import(ImportDefinition),
|
||||
ImportFrom(ImportFromDefinition),
|
||||
|
@ -59,18 +59,18 @@ impl From<ScopeNamedExprId> for Definition {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct ImportDefinition {
|
||||
pub(super) import_id: ScopeImportId,
|
||||
pub(crate) import_id: ScopeImportId,
|
||||
|
||||
/// Index into [`ruff_python_ast::StmtImport::names`].
|
||||
pub(super) alias: u32,
|
||||
pub(crate) alias: u32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct ImportFromDefinition {
|
||||
pub(super) import_id: ScopeImportFromId,
|
||||
pub(crate) import_id: ScopeImportFromId,
|
||||
|
||||
/// Index into [`ruff_python_ast::StmtImportFrom::names`].
|
||||
pub(super) name: u32,
|
||||
pub(crate) name: u32,
|
||||
}
|
||||
|
|
|
@ -15,24 +15,21 @@ use ruff_index::{newtype_index, IndexVec};
|
|||
|
||||
use crate::name::Name;
|
||||
use crate::red_knot::semantic_index::definition::Definition;
|
||||
use crate::red_knot::semantic_index::{scopes_map, symbol_table, SymbolMap};
|
||||
use crate::red_knot::semantic_index::{root_scope, semantic_index, symbol_table, SymbolMap};
|
||||
use crate::Db;
|
||||
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
pub struct Symbol {
|
||||
name: Name,
|
||||
flags: SymbolFlags,
|
||||
scope: FileScopeId,
|
||||
|
||||
/// The nodes that define this symbol, in source order.
|
||||
definitions: SmallVec<[Definition; 4]>,
|
||||
}
|
||||
|
||||
impl Symbol {
|
||||
fn new(name: Name, scope: FileScopeId, definition: Option<Definition>) -> Self {
|
||||
fn new(name: Name, definition: Option<Definition>) -> Self {
|
||||
Self {
|
||||
name,
|
||||
scope,
|
||||
flags: SymbolFlags::empty(),
|
||||
definitions: definition.into_iter().collect(),
|
||||
}
|
||||
|
@ -51,11 +48,6 @@ impl Symbol {
|
|||
&self.name
|
||||
}
|
||||
|
||||
/// The scope in which this symbol is defined.
|
||||
pub fn scope(&self) -> FileScopeId {
|
||||
self.scope
|
||||
}
|
||||
|
||||
/// Is the symbol used in its containing scope?
|
||||
pub fn is_used(&self) -> bool {
|
||||
self.flags.contains(SymbolFlags::IS_USED)
|
||||
|
@ -84,62 +76,72 @@ bitflags! {
|
|||
}
|
||||
|
||||
/// ID that uniquely identifies a public symbol defined in a module's root scope.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
#[salsa::tracked]
|
||||
pub struct PublicSymbolId {
|
||||
scope: ScopeId,
|
||||
symbol: ScopeSymbolId,
|
||||
}
|
||||
|
||||
impl PublicSymbolId {
|
||||
pub(crate) fn new(scope: ScopeId, symbol: ScopeSymbolId) -> Self {
|
||||
Self { scope, symbol }
|
||||
}
|
||||
|
||||
pub fn scope(self) -> ScopeId {
|
||||
self.scope
|
||||
}
|
||||
|
||||
pub(crate) fn scope_symbol(self) -> ScopeSymbolId {
|
||||
self.symbol
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PublicSymbolId> for ScopeSymbolId {
|
||||
fn from(val: PublicSymbolId) -> Self {
|
||||
val.scope_symbol()
|
||||
}
|
||||
#[id]
|
||||
pub(crate) file: VfsFile,
|
||||
#[id]
|
||||
pub(crate) scoped_symbol_id: ScopedSymbolId,
|
||||
}
|
||||
|
||||
/// ID that uniquely identifies a symbol in a file.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct FileSymbolId {
|
||||
scope: FileScopeId,
|
||||
symbol: ScopeSymbolId,
|
||||
scoped_symbol_id: ScopedSymbolId,
|
||||
}
|
||||
|
||||
impl FileSymbolId {
|
||||
pub(super) fn new(scope: FileScopeId, symbol: ScopeSymbolId) -> Self {
|
||||
Self { scope, symbol }
|
||||
pub(super) fn new(scope: FileScopeId, symbol: ScopedSymbolId) -> Self {
|
||||
Self {
|
||||
scope,
|
||||
scoped_symbol_id: symbol,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scope(self) -> FileScopeId {
|
||||
self.scope
|
||||
}
|
||||
|
||||
pub(crate) fn symbol(self) -> ScopeSymbolId {
|
||||
self.symbol
|
||||
pub(crate) fn scoped_symbol_id(self) -> ScopedSymbolId {
|
||||
self.scoped_symbol_id
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FileSymbolId> for ScopeSymbolId {
|
||||
impl From<FileSymbolId> for ScopedSymbolId {
|
||||
fn from(val: FileSymbolId) -> Self {
|
||||
val.symbol()
|
||||
val.scoped_symbol_id()
|
||||
}
|
||||
}
|
||||
|
||||
/// Symbol ID that uniquely identifies a symbol inside a [`Scope`].
|
||||
#[newtype_index]
|
||||
pub(crate) struct ScopeSymbolId;
|
||||
pub struct ScopedSymbolId;
|
||||
|
||||
impl ScopedSymbolId {
|
||||
/// Converts the symbol to a public symbol.
|
||||
///
|
||||
/// # Panics
|
||||
/// May panic if the symbol does not belong to `file` or is not a symbol of `file`'s root scope.
|
||||
pub(crate) fn to_public_symbol(self, db: &dyn Db, file: VfsFile) -> PublicSymbolId {
|
||||
let symbols = public_symbols_map(db, file);
|
||||
symbols.public(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a mapping from [`FileScopeId`] to globally unique [`ScopeId`].
|
||||
#[salsa::tracked(return_ref)]
|
||||
pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap {
|
||||
let index = semantic_index(db, file);
|
||||
|
||||
let scopes: IndexVec<_, _> = index
|
||||
.scopes
|
||||
.indices()
|
||||
.map(|id| ScopeId::new(db, file, id))
|
||||
.collect();
|
||||
|
||||
ScopesMap { scopes }
|
||||
}
|
||||
|
||||
/// Maps from the file specific [`FileScopeId`] to the global [`ScopeId`] that can be used as a Salsa query parameter.
|
||||
///
|
||||
|
@ -152,13 +154,37 @@ pub(crate) struct ScopesMap {
|
|||
}
|
||||
|
||||
impl ScopesMap {
|
||||
pub(super) fn new(scopes: IndexVec<FileScopeId, ScopeId>) -> Self {
|
||||
Self { scopes }
|
||||
}
|
||||
|
||||
/// Gets the program-wide unique scope id for the given file specific `scope_id`.
|
||||
fn get(&self, scope_id: FileScopeId) -> ScopeId {
|
||||
self.scopes[scope_id]
|
||||
fn get(&self, scope: FileScopeId) -> ScopeId {
|
||||
self.scopes[scope]
|
||||
}
|
||||
}
|
||||
|
||||
#[salsa::tracked(return_ref)]
|
||||
pub(crate) fn public_symbols_map(db: &dyn Db, file: VfsFile) -> PublicSymbolsMap {
|
||||
let module_scope = root_scope(db, file);
|
||||
let symbols = symbol_table(db, module_scope);
|
||||
|
||||
let public_symbols: IndexVec<_, _> = symbols
|
||||
.symbol_ids()
|
||||
.map(|id| PublicSymbolId::new(db, file, id))
|
||||
.collect();
|
||||
|
||||
PublicSymbolsMap {
|
||||
symbols: public_symbols,
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps [`LocalSymbolId`] of a file's root scope to the corresponding [`PublicSymbolId`] (Salsa ingredients).
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
pub(crate) struct PublicSymbolsMap {
|
||||
symbols: IndexVec<ScopedSymbolId, PublicSymbolId>,
|
||||
}
|
||||
|
||||
impl PublicSymbolsMap {
|
||||
/// Resolve the [`PublicSymbolId`] for the module-level `symbol_id`.
|
||||
fn public(&self, symbol_id: ScopedSymbolId) -> PublicSymbolId {
|
||||
self.symbols[symbol_id]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -166,18 +192,10 @@ impl ScopesMap {
|
|||
#[salsa::tracked]
|
||||
pub struct ScopeId {
|
||||
#[allow(clippy::used_underscore_binding)]
|
||||
#[id]
|
||||
pub file: VfsFile,
|
||||
pub scope_id: FileScopeId,
|
||||
}
|
||||
|
||||
impl ScopeId {
|
||||
/// Resolves the symbol named `name` in this scope.
|
||||
pub fn symbol(self, db: &dyn Db, name: &str) -> Option<PublicSymbolId> {
|
||||
let symbol_table = symbol_table(db, self);
|
||||
let in_scope_id = symbol_table.symbol_id_by_name(name)?;
|
||||
|
||||
Some(PublicSymbolId::new(self, in_scope_id))
|
||||
}
|
||||
#[id]
|
||||
pub file_scope_id: FileScopeId,
|
||||
}
|
||||
|
||||
/// ID that uniquely identifies a scope inside of a module.
|
||||
|
@ -239,7 +257,7 @@ pub enum ScopeKind {
|
|||
#[derive(Debug)]
|
||||
pub struct SymbolTable {
|
||||
/// The symbols in this scope.
|
||||
symbols: IndexVec<ScopeSymbolId, Symbol>,
|
||||
symbols: IndexVec<ScopedSymbolId, Symbol>,
|
||||
|
||||
/// The symbols indexed by name.
|
||||
symbols_by_name: SymbolMap,
|
||||
|
@ -257,12 +275,12 @@ impl SymbolTable {
|
|||
self.symbols.shrink_to_fit();
|
||||
}
|
||||
|
||||
pub(crate) fn symbol(&self, symbol_id: impl Into<ScopeSymbolId>) -> &Symbol {
|
||||
pub(crate) fn symbol(&self, symbol_id: impl Into<ScopedSymbolId>) -> &Symbol {
|
||||
&self.symbols[symbol_id.into()]
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopeSymbolId> {
|
||||
pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopedSymbolId> {
|
||||
self.symbols.indices()
|
||||
}
|
||||
|
||||
|
@ -277,8 +295,8 @@ impl SymbolTable {
|
|||
Some(self.symbol(id))
|
||||
}
|
||||
|
||||
/// Returns the [`ScopeSymbolId`] of the symbol named `name`.
|
||||
pub(crate) fn symbol_id_by_name(&self, name: &str) -> Option<ScopeSymbolId> {
|
||||
/// Returns the [`ScopedSymbolId`] of the symbol named `name`.
|
||||
pub(crate) fn symbol_id_by_name(&self, name: &str) -> Option<ScopedSymbolId> {
|
||||
let (id, ()) = self
|
||||
.symbols_by_name
|
||||
.raw_entry()
|
||||
|
@ -320,10 +338,9 @@ impl SymbolTableBuilder {
|
|||
pub(super) fn add_or_update_symbol(
|
||||
&mut self,
|
||||
name: Name,
|
||||
scope: FileScopeId,
|
||||
flags: SymbolFlags,
|
||||
definition: Option<Definition>,
|
||||
) -> ScopeSymbolId {
|
||||
) -> ScopedSymbolId {
|
||||
let hash = SymbolTable::hash_name(&name);
|
||||
let entry = self
|
||||
.table
|
||||
|
@ -343,7 +360,7 @@ impl SymbolTableBuilder {
|
|||
*entry.key()
|
||||
}
|
||||
RawEntryMut::Vacant(entry) => {
|
||||
let mut symbol = Symbol::new(name, scope, definition);
|
||||
let mut symbol = Symbol::new(name, definition);
|
||||
symbol.insert_flags(flags);
|
||||
|
||||
let id = self.table.symbols.push(symbol);
|
||||
|
|
684
crates/ruff_python_semantic/src/red_knot/types.rs
Normal file
684
crates/ruff_python_semantic/src/red_knot/types.rs
Normal file
|
@ -0,0 +1,684 @@
|
|||
use salsa::DebugWithDb;
|
||||
|
||||
use ruff_db::parsed::parsed_module;
|
||||
use ruff_db::vfs::VfsFile;
|
||||
use ruff_index::newtype_index;
|
||||
use ruff_python_ast as ast;
|
||||
|
||||
use crate::name::Name;
|
||||
use crate::red_knot::semantic_index::ast_ids::{AstIdNode, ScopeAstIdNode};
|
||||
use crate::red_knot::semantic_index::symbol::{FileScopeId, PublicSymbolId, ScopeId};
|
||||
use crate::red_knot::semantic_index::{
|
||||
public_symbol, root_scope, semantic_index, symbol_table, NodeWithScopeId,
|
||||
};
|
||||
use crate::red_knot::types::infer::{TypeInference, TypeInferenceBuilder};
|
||||
use crate::red_knot::FxIndexSet;
|
||||
use crate::Db;
|
||||
|
||||
mod display;
|
||||
mod infer;
|
||||
|
||||
/// Infers the type of `expr`.
|
||||
///
|
||||
/// Calling this function from a salsa query adds a dependency on [`semantic_index`]
|
||||
/// which changes with every AST change. That's why you should only call
|
||||
/// this function for the current file that's being analyzed and not for
|
||||
/// a dependency (or the query reruns whenever a dependency change).
|
||||
///
|
||||
/// Prefer [`public_symbol_ty`] when resolving the type of symbol from another file.
|
||||
#[tracing::instrument(level = "debug", skip(db))]
|
||||
pub(crate) fn expression_ty(db: &dyn Db, file: VfsFile, expression: &ast::Expr) -> Type {
|
||||
let index = semantic_index(db, file);
|
||||
let file_scope = index.expression_scope_id(expression);
|
||||
let expression_id = expression.scope_ast_id(db, file, file_scope);
|
||||
let scope = file_scope.to_scope_id(db, file);
|
||||
|
||||
infer_types(db, scope).expression_ty(expression_id)
|
||||
}
|
||||
|
||||
/// Infers the type of a public symbol.
|
||||
///
|
||||
/// This is a Salsa query to get symbol-level invalidation instead of file-level dependency invalidation.
|
||||
/// Without this being a query, changing any public type of a module would invalidate the type inference
|
||||
/// for the module scope of its dependents and the transitive dependents because.
|
||||
///
|
||||
/// For example if we have
|
||||
/// ```python
|
||||
/// # a.py
|
||||
/// import x from b
|
||||
///
|
||||
/// # b.py
|
||||
///
|
||||
/// x = 20
|
||||
/// ```
|
||||
///
|
||||
/// And x is now changed from `x = 20` to `x = 30`. The following happens:
|
||||
///
|
||||
/// * The module level types of `b.py` change because `x` now is a `Literal[30]`.
|
||||
/// * The module level types of `a.py` change because the imported symbol `x` now has a `Literal[30]` type
|
||||
/// * The module level types of any dependents of `a.py` change because the imported symbol `x` now has a `Literal[30]` type
|
||||
/// * And so on for all transitive dependencies.
|
||||
///
|
||||
/// This being a query ensures that the invalidation short-circuits if the type of this symbol didn't change.
|
||||
#[salsa::tracked]
|
||||
pub(crate) fn public_symbol_ty(db: &dyn Db, symbol: PublicSymbolId) -> Type {
|
||||
let _ = tracing::debug_span!("public_symbol_ty", "{:?}", symbol.debug(db));
|
||||
|
||||
let file = symbol.file(db);
|
||||
let scope = root_scope(db, file);
|
||||
|
||||
let inference = infer_types(db, scope);
|
||||
inference.symbol_ty(symbol.scoped_symbol_id(db))
|
||||
}
|
||||
|
||||
/// Shorthand for [`public_symbol_ty()`] that takes a symbol name instead of a [`PublicSymbolId`].
|
||||
pub fn public_symbol_ty_by_name(db: &dyn Db, file: VfsFile, name: &str) -> Option<Type> {
|
||||
let symbol = public_symbol(db, file, name)?;
|
||||
Some(public_symbol_ty(db, symbol))
|
||||
}
|
||||
|
||||
/// Infers all types for `scope`.
|
||||
#[salsa::tracked(return_ref)]
|
||||
pub(crate) fn infer_types(db: &dyn Db, scope: ScopeId) -> TypeInference {
|
||||
let file = scope.file(db);
|
||||
// Using the index here is fine because the code below depends on the AST anyway.
|
||||
// The isolation of the query is by the return inferred types.
|
||||
let index = semantic_index(db, file);
|
||||
|
||||
let scope_id = scope.file_scope_id(db);
|
||||
let node = index.scope_node(scope_id);
|
||||
|
||||
let mut context = TypeInferenceBuilder::new(db, scope, index);
|
||||
|
||||
match node {
|
||||
NodeWithScopeId::Module => {
|
||||
let parsed = parsed_module(db.upcast(), file);
|
||||
context.infer_module(parsed.syntax());
|
||||
}
|
||||
NodeWithScopeId::Class(class_id) => {
|
||||
let class = ast::StmtClassDef::lookup(db, file, class_id);
|
||||
context.infer_class_body(class);
|
||||
}
|
||||
NodeWithScopeId::ClassTypeParams(class_id) => {
|
||||
let class = ast::StmtClassDef::lookup(db, file, class_id);
|
||||
context.infer_class_type_params(class);
|
||||
}
|
||||
NodeWithScopeId::Function(function_id) => {
|
||||
let function = ast::StmtFunctionDef::lookup(db, file, function_id);
|
||||
context.infer_function_body(function);
|
||||
}
|
||||
NodeWithScopeId::FunctionTypeParams(function_id) => {
|
||||
let function = ast::StmtFunctionDef::lookup(db, file, function_id);
|
||||
context.infer_function_type_params(function);
|
||||
}
|
||||
}
|
||||
|
||||
context.finish()
|
||||
}
|
||||
|
||||
/// unique ID for a type
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum Type {
|
||||
/// the dynamic type: a statically-unknown set of values
|
||||
Any,
|
||||
/// the empty set of values
|
||||
Never,
|
||||
/// unknown type (no annotation)
|
||||
/// equivalent to Any, or to object in strict mode
|
||||
Unknown,
|
||||
/// name is not bound to any value
|
||||
Unbound,
|
||||
/// the None object (TODO remove this in favor of Instance(types.NoneType)
|
||||
None,
|
||||
/// a specific function object
|
||||
Function(TypeId<ScopedFunctionTypeId>),
|
||||
/// a specific module object
|
||||
Module(TypeId<ScopedModuleTypeId>),
|
||||
/// a specific class object
|
||||
Class(TypeId<ScopedClassTypeId>),
|
||||
/// the set of Python objects with the given class in their __class__'s method resolution order
|
||||
Instance(TypeId<ScopedClassTypeId>),
|
||||
Union(TypeId<ScopedUnionTypeId>),
|
||||
Intersection(TypeId<ScopedIntersectionTypeId>),
|
||||
IntLiteral(i64),
|
||||
// TODO protocols, callable types, overloads, generics, type vars
|
||||
}
|
||||
|
||||
impl Type {
|
||||
pub const fn is_unbound(&self) -> bool {
|
||||
matches!(self, Type::Unbound)
|
||||
}
|
||||
|
||||
pub const fn is_unknown(&self) -> bool {
|
||||
matches!(self, Type::Unknown)
|
||||
}
|
||||
|
||||
pub fn member(&self, context: &TypingContext, name: &Name) -> Option<Type> {
|
||||
match self {
|
||||
Type::Any => Some(Type::Any),
|
||||
Type::Never => todo!("attribute lookup on Never type"),
|
||||
Type::Unknown => Some(Type::Unknown),
|
||||
Type::Unbound => todo!("attribute lookup on Unbound type"),
|
||||
Type::None => todo!("attribute lookup on None type"),
|
||||
Type::Function(_) => todo!("attribute lookup on Function type"),
|
||||
Type::Module(module) => module.member(context, name),
|
||||
Type::Class(class) => class.class_member(context, name),
|
||||
Type::Instance(_) => {
|
||||
// TODO MRO? get_own_instance_member, get_instance_member
|
||||
todo!("attribute lookup on Instance type")
|
||||
}
|
||||
Type::Union(union_id) => {
|
||||
let _union = union_id.lookup(context);
|
||||
// TODO perform the get_member on each type in the union
|
||||
// TODO return the union of those results
|
||||
// TODO if any of those results is `None` then include Unknown in the result union
|
||||
todo!("attribute lookup on Union type")
|
||||
}
|
||||
Type::Intersection(_) => {
|
||||
// TODO perform the get_member on each type in the intersection
|
||||
// TODO return the intersection of those results
|
||||
todo!("attribute lookup on Intersection type")
|
||||
}
|
||||
Type::IntLiteral(_) => {
|
||||
// TODO raise error
|
||||
Some(Type::Unknown)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// ID that uniquely identifies a type in a program.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct TypeId<L> {
|
||||
/// The scope in which this type is defined or was created.
|
||||
scope: ScopeId,
|
||||
/// The type's local ID in its scope.
|
||||
scoped: L,
|
||||
}
|
||||
|
||||
impl<Id> TypeId<Id>
|
||||
where
|
||||
Id: Copy,
|
||||
{
|
||||
pub fn scope(&self) -> ScopeId {
|
||||
self.scope
|
||||
}
|
||||
|
||||
pub fn scoped_id(&self) -> Id {
|
||||
self.scoped
|
||||
}
|
||||
|
||||
/// Resolves the type ID to the actual type.
|
||||
pub(crate) fn lookup<'a>(self, context: &'a TypingContext) -> &'a Id::Ty
|
||||
where
|
||||
Id: ScopedTypeId,
|
||||
{
|
||||
let types = context.types(self.scope);
|
||||
self.scoped.lookup_scoped(types)
|
||||
}
|
||||
}
|
||||
|
||||
/// ID that uniquely identifies a type in a scope.
|
||||
pub(crate) trait ScopedTypeId {
|
||||
/// The type that this ID points to.
|
||||
type Ty;
|
||||
|
||||
/// Looks up the type in `index`.
|
||||
///
|
||||
/// ## Panics
|
||||
/// May panic if this type is from another scope than `index`, or might just return an invalid type.
|
||||
fn lookup_scoped(self, index: &TypeInference) -> &Self::Ty;
|
||||
}
|
||||
|
||||
/// ID uniquely identifying a function type in a `scope`.
|
||||
#[newtype_index]
|
||||
pub struct ScopedFunctionTypeId;
|
||||
|
||||
impl ScopedTypeId for ScopedFunctionTypeId {
|
||||
type Ty = FunctionType;
|
||||
|
||||
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
|
||||
types.function_ty(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub struct FunctionType {
|
||||
/// name of the function at definition
|
||||
name: Name,
|
||||
/// types of all decorators on this function
|
||||
decorators: Vec<Type>,
|
||||
}
|
||||
|
||||
impl FunctionType {
|
||||
fn name(&self) -> &str {
|
||||
self.name.as_str()
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub(crate) fn decorators(&self) -> &[Type] {
|
||||
self.decorators.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
#[newtype_index]
|
||||
pub struct ScopedClassTypeId;
|
||||
|
||||
impl ScopedTypeId for ScopedClassTypeId {
|
||||
type Ty = ClassType;
|
||||
|
||||
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
|
||||
types.class_ty(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeId<ScopedClassTypeId> {
|
||||
/// Returns the class member of this class named `name`.
|
||||
///
|
||||
/// The member resolves to a member of the class itself or any of its bases.
|
||||
fn class_member(self, context: &TypingContext, name: &Name) -> Option<Type> {
|
||||
if let Some(member) = self.own_class_member(context, name) {
|
||||
return Some(member);
|
||||
}
|
||||
|
||||
let class = self.lookup(context);
|
||||
for base in &class.bases {
|
||||
if let Some(member) = base.member(context, name) {
|
||||
return Some(member);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns the inferred type of the class member named `name`.
|
||||
fn own_class_member(self, context: &TypingContext, name: &Name) -> Option<Type> {
|
||||
let class = self.lookup(context);
|
||||
|
||||
let symbols = symbol_table(context.db, class.body_scope);
|
||||
let symbol = symbols.symbol_id_by_name(name)?;
|
||||
let types = context.types(class.body_scope);
|
||||
|
||||
Some(types.symbol_ty(symbol))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub struct ClassType {
|
||||
/// Name of the class at definition
|
||||
name: Name,
|
||||
|
||||
/// Types of all class bases
|
||||
bases: Vec<Type>,
|
||||
|
||||
body_scope: ScopeId,
|
||||
}
|
||||
|
||||
impl ClassType {
|
||||
fn name(&self) -> &str {
|
||||
self.name.as_str()
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub(super) fn bases(&self) -> &[Type] {
|
||||
self.bases.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
#[newtype_index]
|
||||
pub struct ScopedUnionTypeId;
|
||||
|
||||
impl ScopedTypeId for ScopedUnionTypeId {
|
||||
type Ty = UnionType;
|
||||
|
||||
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
|
||||
types.union_ty(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub struct UnionType {
|
||||
// the union type includes values in any of these types
|
||||
elements: FxIndexSet<Type>,
|
||||
}
|
||||
|
||||
struct UnionTypeBuilder<'a> {
|
||||
elements: FxIndexSet<Type>,
|
||||
context: &'a TypingContext<'a>,
|
||||
}
|
||||
|
||||
impl<'a> UnionTypeBuilder<'a> {
|
||||
fn new(context: &'a TypingContext<'a>) -> Self {
|
||||
Self {
|
||||
context,
|
||||
elements: FxIndexSet::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a type to this union.
|
||||
fn add(mut self, ty: Type) -> Self {
|
||||
match ty {
|
||||
Type::Union(union_id) => {
|
||||
let union = union_id.lookup(self.context);
|
||||
self.elements.extend(&union.elements);
|
||||
}
|
||||
_ => {
|
||||
self.elements.insert(ty);
|
||||
}
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
fn build(self) -> UnionType {
|
||||
UnionType {
|
||||
elements: self.elements,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[newtype_index]
|
||||
pub struct ScopedIntersectionTypeId;
|
||||
|
||||
impl ScopedTypeId for ScopedIntersectionTypeId {
|
||||
type Ty = IntersectionType;
|
||||
|
||||
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
|
||||
types.intersection_ty(self)
|
||||
}
|
||||
}
|
||||
|
||||
// Negation types aren't expressible in annotations, and are most likely to arise from type
|
||||
// narrowing along with intersections (e.g. `if not isinstance(...)`), so we represent them
|
||||
// directly in intersections rather than as a separate type. This sacrifices some efficiency in the
|
||||
// case where a Not appears outside an intersection (unclear when that could even happen, but we'd
|
||||
// have to represent it as a single-element intersection if it did) in exchange for better
|
||||
// efficiency in the within-intersection case.
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct IntersectionType {
|
||||
// the intersection type includes only values in all of these types
|
||||
positive: FxIndexSet<Type>,
|
||||
// the intersection type does not include any value in any of these types
|
||||
negative: FxIndexSet<Type>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct ScopedModuleTypeId;
|
||||
|
||||
impl ScopedTypeId for ScopedModuleTypeId {
|
||||
type Ty = ModuleType;
|
||||
|
||||
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
|
||||
types.module_ty()
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeId<ScopedModuleTypeId> {
|
||||
fn member(self, context: &TypingContext, name: &Name) -> Option<Type> {
|
||||
context.public_symbol_ty(self.scope.file(context.db), name)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub struct ModuleType {
|
||||
file: VfsFile,
|
||||
}
|
||||
|
||||
/// Context in which to resolve types.
|
||||
///
|
||||
/// This abstraction is necessary to support a uniform API that can be used
|
||||
/// while in the process of building the type inference structure for a scope
|
||||
/// but also when all types should be resolved by querying the db.
|
||||
pub struct TypingContext<'a> {
|
||||
db: &'a dyn Db,
|
||||
|
||||
/// The Local type inference scope that is in the process of being built.
|
||||
///
|
||||
/// Bypass the `db` when resolving the types for this scope.
|
||||
local: Option<(ScopeId, &'a TypeInference)>,
|
||||
}
|
||||
|
||||
impl<'a> TypingContext<'a> {
|
||||
/// Creates a context that resolves all types by querying the db.
|
||||
#[allow(unused)]
|
||||
pub(super) fn global(db: &'a dyn Db) -> Self {
|
||||
Self { db, local: None }
|
||||
}
|
||||
|
||||
/// Creates a context that by-passes the `db` when resolving types from `scope_id` and instead uses `types`.
|
||||
fn scoped(db: &'a dyn Db, scope_id: ScopeId, types: &'a TypeInference) -> Self {
|
||||
Self {
|
||||
db,
|
||||
local: Some((scope_id, types)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`TypeInference`] results (not guaranteed to be complete) for `scope_id`.
|
||||
fn types(&self, scope_id: ScopeId) -> &'a TypeInference {
|
||||
if let Some((scope, local_types)) = self.local {
|
||||
if scope == scope_id {
|
||||
return local_types;
|
||||
}
|
||||
}
|
||||
|
||||
infer_types(self.db, scope_id)
|
||||
}
|
||||
|
||||
fn module_ty(&self, file: VfsFile) -> Type {
|
||||
let scope = root_scope(self.db, file);
|
||||
|
||||
Type::Module(TypeId {
|
||||
scope,
|
||||
scoped: ScopedModuleTypeId,
|
||||
})
|
||||
}
|
||||
|
||||
/// Resolves the public type of a symbol named `name` defined in `file`.
|
||||
///
|
||||
/// This function calls [`public_symbol_ty`] if the local scope isn't the module scope of `file`.
|
||||
/// It otherwise tries to resolve the symbol type locally.
|
||||
fn public_symbol_ty(&self, file: VfsFile, name: &Name) -> Option<Type> {
|
||||
let symbol = public_symbol(self.db, file, name)?;
|
||||
|
||||
if let Some((scope, local_types)) = self.local {
|
||||
if scope.file_scope_id(self.db) == FileScopeId::root() && scope.file(self.db) == file {
|
||||
return Some(local_types.symbol_ty(symbol.scoped_symbol_id(self.db)));
|
||||
}
|
||||
}
|
||||
|
||||
Some(public_symbol_ty(self.db, symbol))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use ruff_db::file_system::FileSystemPathBuf;
|
||||
use ruff_db::parsed::parsed_module;
|
||||
use ruff_db::vfs::system_path_to_file;
|
||||
|
||||
use crate::db::tests::{
|
||||
assert_will_not_run_function_query, assert_will_run_function_query, TestDb,
|
||||
};
|
||||
use crate::module::resolver::{set_module_resolution_settings, ModuleResolutionSettings};
|
||||
use crate::red_knot::semantic_index::root_scope;
|
||||
use crate::red_knot::types::{
|
||||
expression_ty, infer_types, public_symbol_ty_by_name, TypingContext,
|
||||
};
|
||||
|
||||
fn setup_db() -> TestDb {
|
||||
let mut db = TestDb::new();
|
||||
set_module_resolution_settings(
|
||||
&mut db,
|
||||
ModuleResolutionSettings {
|
||||
extra_paths: vec![],
|
||||
workspace_root: FileSystemPathBuf::from("/src"),
|
||||
site_packages: None,
|
||||
custom_typeshed: None,
|
||||
},
|
||||
);
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_inference() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system().write_file("/src/a.py", "x = 10")?;
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
|
||||
let parsed = parsed_module(&db, a);
|
||||
|
||||
let statement = parsed.suite().first().unwrap().as_assign_stmt().unwrap();
|
||||
|
||||
let literal_ty = expression_ty(&db, a, &statement.value);
|
||||
|
||||
assert_eq!(
|
||||
format!("{}", literal_ty.display(&TypingContext::global(&db))),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dependency_public_symbol_type_change() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
("/src/a.py", "from foo import x"),
|
||||
("/src/foo.py", "x = 10\ndef foo(): ..."),
|
||||
])?;
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
// Change `x` to a different value
|
||||
db.memory_file_system()
|
||||
.write_file("/src/foo.py", "x = 20\ndef foo(): ...")?;
|
||||
|
||||
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
|
||||
foo.touch(&mut db);
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
|
||||
db.clear_salsa_events();
|
||||
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty_2.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[20]"
|
||||
);
|
||||
|
||||
let events = db.take_salsa_events();
|
||||
|
||||
let a_root_scope = root_scope(&db, a);
|
||||
assert_will_run_function_query::<infer_types, _, _>(
|
||||
&db,
|
||||
|ty| &ty.function,
|
||||
a_root_scope,
|
||||
&events,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dependency_non_public_symbol_change() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
("/src/a.py", "from foo import x"),
|
||||
("/src/foo.py", "x = 10\ndef foo(): y = 1"),
|
||||
])?;
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
db.memory_file_system()
|
||||
.write_file("/src/foo.py", "x = 10\ndef foo(): pass")?;
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
|
||||
|
||||
foo.touch(&mut db);
|
||||
|
||||
db.clear_salsa_events();
|
||||
|
||||
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty_2.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
let events = db.take_salsa_events();
|
||||
|
||||
let a_root_scope = root_scope(&db, a);
|
||||
|
||||
assert_will_not_run_function_query::<infer_types, _, _>(
|
||||
&db,
|
||||
|ty| &ty.function,
|
||||
a_root_scope,
|
||||
&events,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dependency_unrelated_public_symbol() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
("/src/a.py", "from foo import x"),
|
||||
("/src/foo.py", "x = 10\ny = 20"),
|
||||
])?;
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
db.memory_file_system()
|
||||
.write_file("/src/foo.py", "x = 10\ny = 30")?;
|
||||
|
||||
let a = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
|
||||
|
||||
foo.touch(&mut db);
|
||||
|
||||
db.clear_salsa_events();
|
||||
|
||||
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
x_ty_2.display(&TypingContext::global(&db)).to_string(),
|
||||
"Literal[10]"
|
||||
);
|
||||
|
||||
let events = db.take_salsa_events();
|
||||
|
||||
let a_root_scope = root_scope(&db, a);
|
||||
assert_will_not_run_function_query::<infer_types, _, _>(
|
||||
&db,
|
||||
|ty| &ty.function,
|
||||
a_root_scope,
|
||||
&events,
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
175
crates/ruff_python_semantic/src/red_knot/types/display.rs
Normal file
175
crates/ruff_python_semantic/src/red_knot/types/display.rs
Normal file
|
@ -0,0 +1,175 @@
|
|||
//! Display implementations for types.
|
||||
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
use crate::red_knot::types::{IntersectionType, Type, TypingContext, UnionType};
|
||||
|
||||
impl Type {
|
||||
pub fn display<'a>(&'a self, context: &'a TypingContext) -> DisplayType<'a> {
|
||||
DisplayType { ty: self, context }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct DisplayType<'a> {
|
||||
ty: &'a Type,
|
||||
context: &'a TypingContext<'a>,
|
||||
}
|
||||
|
||||
impl Display for DisplayType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self.ty {
|
||||
Type::Any => f.write_str("Any"),
|
||||
Type::Never => f.write_str("Never"),
|
||||
Type::Unknown => f.write_str("Unknown"),
|
||||
Type::Unbound => f.write_str("Unbound"),
|
||||
Type::None => f.write_str("None"),
|
||||
Type::Module(module_id) => {
|
||||
write!(
|
||||
f,
|
||||
"<module '{:?}'>",
|
||||
module_id
|
||||
.scope
|
||||
.file(self.context.db)
|
||||
.path(self.context.db.upcast())
|
||||
)
|
||||
}
|
||||
// TODO functions and classes should display using a fully qualified name
|
||||
Type::Class(class_id) => {
|
||||
let class = class_id.lookup(self.context);
|
||||
|
||||
f.write_str("Literal[")?;
|
||||
f.write_str(class.name())?;
|
||||
f.write_str("]")
|
||||
}
|
||||
Type::Instance(class_id) => {
|
||||
let class = class_id.lookup(self.context);
|
||||
f.write_str(class.name())
|
||||
}
|
||||
Type::Function(function_id) => {
|
||||
let function = function_id.lookup(self.context);
|
||||
f.write_str(function.name())
|
||||
}
|
||||
Type::Union(union_id) => {
|
||||
let union = union_id.lookup(self.context);
|
||||
|
||||
union.display(self.context).fmt(f)
|
||||
}
|
||||
Type::Intersection(intersection_id) => {
|
||||
let intersection = intersection_id.lookup(self.context);
|
||||
|
||||
intersection.display(self.context).fmt(f)
|
||||
}
|
||||
Type::IntLiteral(n) => write!(f, "Literal[{n}]"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for DisplayType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
std::fmt::Display::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl UnionType {
|
||||
fn display<'a>(&'a self, context: &'a TypingContext<'a>) -> DisplayUnionType<'a> {
|
||||
DisplayUnionType { context, ty: self }
|
||||
}
|
||||
}
|
||||
|
||||
struct DisplayUnionType<'a> {
|
||||
ty: &'a UnionType,
|
||||
context: &'a TypingContext<'a>,
|
||||
}
|
||||
|
||||
impl Display for DisplayUnionType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
let union = self.ty;
|
||||
|
||||
let (int_literals, other_types): (Vec<Type>, Vec<Type>) = union
|
||||
.elements
|
||||
.iter()
|
||||
.copied()
|
||||
.partition(|ty| matches!(ty, Type::IntLiteral(_)));
|
||||
|
||||
let mut first = true;
|
||||
if !int_literals.is_empty() {
|
||||
f.write_str("Literal[")?;
|
||||
let mut nums: Vec<_> = int_literals
|
||||
.into_iter()
|
||||
.filter_map(|ty| {
|
||||
if let Type::IntLiteral(n) = ty {
|
||||
Some(n)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
nums.sort_unstable();
|
||||
for num in nums {
|
||||
if !first {
|
||||
f.write_str(", ")?;
|
||||
}
|
||||
write!(f, "{num}")?;
|
||||
first = false;
|
||||
}
|
||||
f.write_str("]")?;
|
||||
}
|
||||
|
||||
for ty in other_types {
|
||||
if !first {
|
||||
f.write_str(" | ")?;
|
||||
};
|
||||
first = false;
|
||||
write!(f, "{}", ty.display(self.context))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for DisplayUnionType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
std::fmt::Display::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntersectionType {
|
||||
fn display<'a>(&'a self, context: &'a TypingContext<'a>) -> DisplayIntersectionType<'a> {
|
||||
DisplayIntersectionType { ty: self, context }
|
||||
}
|
||||
}
|
||||
|
||||
struct DisplayIntersectionType<'a> {
|
||||
ty: &'a IntersectionType,
|
||||
context: &'a TypingContext<'a>,
|
||||
}
|
||||
|
||||
impl Display for DisplayIntersectionType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
let mut first = true;
|
||||
for (neg, ty) in self
|
||||
.ty
|
||||
.positive
|
||||
.iter()
|
||||
.map(|ty| (false, ty))
|
||||
.chain(self.ty.negative.iter().map(|ty| (true, ty)))
|
||||
{
|
||||
if !first {
|
||||
f.write_str(" & ")?;
|
||||
};
|
||||
first = false;
|
||||
if neg {
|
||||
f.write_str("~")?;
|
||||
};
|
||||
write!(f, "{}", ty.display(self.context))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for DisplayIntersectionType<'_> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
std::fmt::Display::fmt(self, f)
|
||||
}
|
||||
}
|
945
crates/ruff_python_semantic/src/red_knot/types/infer.rs
Normal file
945
crates/ruff_python_semantic/src/red_knot/types/infer.rs
Normal file
|
@ -0,0 +1,945 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use ruff_db::vfs::VfsFile;
|
||||
use ruff_index::IndexVec;
|
||||
use ruff_python_ast as ast;
|
||||
use ruff_python_ast::{ExprContext, TypeParams};
|
||||
|
||||
use crate::module::resolver::resolve_module;
|
||||
use crate::module::ModuleName;
|
||||
use crate::name::Name;
|
||||
use crate::red_knot::semantic_index::ast_ids::{ScopeAstIdNode, ScopeExpressionId};
|
||||
use crate::red_knot::semantic_index::definition::{
|
||||
Definition, ImportDefinition, ImportFromDefinition,
|
||||
};
|
||||
use crate::red_knot::semantic_index::symbol::{
|
||||
FileScopeId, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable,
|
||||
};
|
||||
use crate::red_knot::semantic_index::{symbol_table, ChildrenIter, SemanticIndex};
|
||||
use crate::red_knot::types::{
|
||||
ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId, ScopedFunctionTypeId,
|
||||
ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext, UnionType,
|
||||
UnionTypeBuilder,
|
||||
};
|
||||
use crate::Db;
|
||||
|
||||
/// The inferred types for a single scope.
|
||||
#[derive(Debug, Eq, PartialEq, Default, Clone)]
|
||||
pub(crate) struct TypeInference {
|
||||
/// The type of the module if the scope is a module scope.
|
||||
module_type: Option<ModuleType>,
|
||||
|
||||
/// The types of the defined classes in this scope.
|
||||
class_types: IndexVec<ScopedClassTypeId, ClassType>,
|
||||
|
||||
/// The types of the defined functions in this scope.
|
||||
function_types: IndexVec<ScopedFunctionTypeId, FunctionType>,
|
||||
|
||||
union_types: IndexVec<ScopedUnionTypeId, UnionType>,
|
||||
intersection_types: IndexVec<ScopedIntersectionTypeId, IntersectionType>,
|
||||
|
||||
/// The types of every expression in this scope.
|
||||
expression_tys: IndexVec<ScopeExpressionId, Type>,
|
||||
|
||||
/// The public types of every symbol in this scope.
|
||||
symbol_tys: IndexVec<ScopedSymbolId, Type>,
|
||||
}
|
||||
|
||||
impl TypeInference {
|
||||
#[allow(unused)]
|
||||
pub(super) fn expression_ty(&self, expression: ScopeExpressionId) -> Type {
|
||||
self.expression_tys[expression]
|
||||
}
|
||||
|
||||
pub(super) fn symbol_ty(&self, symbol: ScopedSymbolId) -> Type {
|
||||
self.symbol_tys[symbol]
|
||||
}
|
||||
|
||||
pub(super) fn module_ty(&self) -> &ModuleType {
|
||||
self.module_type.as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub(super) fn class_ty(&self, id: ScopedClassTypeId) -> &ClassType {
|
||||
&self.class_types[id]
|
||||
}
|
||||
|
||||
pub(super) fn function_ty(&self, id: ScopedFunctionTypeId) -> &FunctionType {
|
||||
&self.function_types[id]
|
||||
}
|
||||
|
||||
pub(super) fn union_ty(&self, id: ScopedUnionTypeId) -> &UnionType {
|
||||
&self.union_types[id]
|
||||
}
|
||||
|
||||
pub(super) fn intersection_ty(&self, id: ScopedIntersectionTypeId) -> &IntersectionType {
|
||||
&self.intersection_types[id]
|
||||
}
|
||||
|
||||
fn shrink_to_fit(&mut self) {
|
||||
self.class_types.shrink_to_fit();
|
||||
self.function_types.shrink_to_fit();
|
||||
self.union_types.shrink_to_fit();
|
||||
self.intersection_types.shrink_to_fit();
|
||||
|
||||
self.expression_tys.shrink_to_fit();
|
||||
self.symbol_tys.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder to infer all types in a [`ScopeId`].
|
||||
pub(super) struct TypeInferenceBuilder<'a> {
|
||||
db: &'a dyn Db,
|
||||
|
||||
// Cached lookups
|
||||
index: &'a SemanticIndex,
|
||||
scope: ScopeId,
|
||||
file_scope_id: FileScopeId,
|
||||
file_id: VfsFile,
|
||||
symbol_table: Arc<SymbolTable>,
|
||||
|
||||
/// The type inference results
|
||||
types: TypeInference,
|
||||
definition_tys: FxHashMap<Definition, Type>,
|
||||
children_scopes: ChildrenIter<'a>,
|
||||
}
|
||||
|
||||
impl<'a> TypeInferenceBuilder<'a> {
|
||||
/// Creates a new builder for inferring the types of `scope`.
|
||||
pub(super) fn new(db: &'a dyn Db, scope: ScopeId, index: &'a SemanticIndex) -> Self {
|
||||
let file_scope_id = scope.file_scope_id(db);
|
||||
let file = scope.file(db);
|
||||
let children_scopes = index.child_scopes(file_scope_id);
|
||||
let symbol_table = index.symbol_table(file_scope_id);
|
||||
|
||||
Self {
|
||||
index,
|
||||
file_scope_id,
|
||||
file_id: file,
|
||||
scope,
|
||||
symbol_table,
|
||||
|
||||
db,
|
||||
types: TypeInference::default(),
|
||||
definition_tys: FxHashMap::default(),
|
||||
children_scopes,
|
||||
}
|
||||
}
|
||||
|
||||
/// Infers the types of a `module`.
|
||||
pub(super) fn infer_module(&mut self, module: &ast::ModModule) {
|
||||
self.infer_body(&module.body);
|
||||
}
|
||||
|
||||
pub(super) fn infer_class_type_params(&mut self, class: &ast::StmtClassDef) {
|
||||
if let Some(type_params) = class.type_params.as_deref() {
|
||||
self.infer_type_parameters(type_params);
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn infer_class_body(&mut self, class: &ast::StmtClassDef) {
|
||||
self.infer_body(&class.body);
|
||||
}
|
||||
|
||||
pub(super) fn infer_function_type_params(&mut self, function: &ast::StmtFunctionDef) {
|
||||
if let Some(type_params) = function.type_params.as_deref() {
|
||||
self.infer_type_parameters(type_params);
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) {
|
||||
self.infer_body(&function.body);
|
||||
}
|
||||
|
||||
fn infer_body(&mut self, suite: &[ast::Stmt]) {
|
||||
for statement in suite {
|
||||
self.infer_statement(statement);
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_statement(&mut self, statement: &ast::Stmt) {
|
||||
match statement {
|
||||
ast::Stmt::FunctionDef(function) => self.infer_function_definition_statement(function),
|
||||
ast::Stmt::ClassDef(class) => self.infer_class_definition_statement(class),
|
||||
ast::Stmt::Expr(ast::StmtExpr { range: _, value }) => {
|
||||
self.infer_expression(value);
|
||||
}
|
||||
ast::Stmt::If(if_statement) => self.infer_if_statement(if_statement),
|
||||
ast::Stmt::Assign(assign) => self.infer_assignment_statement(assign),
|
||||
ast::Stmt::AnnAssign(assign) => self.infer_annotated_assignment_statement(assign),
|
||||
ast::Stmt::For(for_statement) => self.infer_for_statement(for_statement),
|
||||
ast::Stmt::Import(import) => self.infer_import_statement(import),
|
||||
ast::Stmt::ImportFrom(import) => self.infer_import_from_statement(import),
|
||||
ast::Stmt::Break(_) | ast::Stmt::Continue(_) | ast::Stmt::Pass(_) => {
|
||||
// No-op
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) {
|
||||
let ast::StmtFunctionDef {
|
||||
range: _,
|
||||
is_async: _,
|
||||
name,
|
||||
type_params: _,
|
||||
parameters: _,
|
||||
returns,
|
||||
body: _,
|
||||
decorator_list,
|
||||
} = function;
|
||||
|
||||
let function_id = function.scope_ast_id(self.db, self.file_id, self.file_scope_id);
|
||||
let decorator_tys = decorator_list
|
||||
.iter()
|
||||
.map(|decorator| self.infer_decorator(decorator))
|
||||
.collect();
|
||||
|
||||
// TODO: Infer parameters
|
||||
|
||||
if let Some(return_ty) = returns {
|
||||
self.infer_expression(return_ty);
|
||||
}
|
||||
|
||||
let function_ty = self.function_ty(FunctionType {
|
||||
name: Name::new(&name.id),
|
||||
decorators: decorator_tys,
|
||||
});
|
||||
|
||||
// Skip over the function or type params child scope.
|
||||
let (_, scope) = self.children_scopes.next().unwrap();
|
||||
|
||||
assert!(matches!(
|
||||
scope.kind(),
|
||||
ScopeKind::Function | ScopeKind::Annotation
|
||||
));
|
||||
|
||||
self.definition_tys
|
||||
.insert(Definition::FunctionDef(function_id), function_ty);
|
||||
}
|
||||
|
||||
fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) {
|
||||
let ast::StmtClassDef {
|
||||
range: _,
|
||||
name,
|
||||
type_params,
|
||||
decorator_list,
|
||||
arguments,
|
||||
body: _,
|
||||
} = class;
|
||||
|
||||
let class_id = class.scope_ast_id(self.db, self.file_id, self.file_scope_id);
|
||||
|
||||
for decorator in decorator_list {
|
||||
self.infer_decorator(decorator);
|
||||
}
|
||||
|
||||
let bases = arguments
|
||||
.as_deref()
|
||||
.map(|arguments| self.infer_arguments(arguments))
|
||||
.unwrap_or(Vec::new());
|
||||
|
||||
// If the class has type parameters, then the class body scope is the first child scope of the type parameter's scope
|
||||
// Otherwise the next scope must be the class definition scope.
|
||||
let (class_body_scope_id, class_body_scope) = if type_params.is_some() {
|
||||
let (type_params_scope, _) = self.children_scopes.next().unwrap();
|
||||
self.index.child_scopes(type_params_scope).next().unwrap()
|
||||
} else {
|
||||
self.children_scopes.next().unwrap()
|
||||
};
|
||||
|
||||
assert_eq!(class_body_scope.kind(), ScopeKind::Class);
|
||||
|
||||
let class_ty = self.class_ty(ClassType {
|
||||
name: Name::new(name),
|
||||
bases,
|
||||
body_scope: class_body_scope_id.to_scope_id(self.db, self.file_id),
|
||||
});
|
||||
|
||||
self.definition_tys
|
||||
.insert(Definition::ClassDef(class_id), class_ty);
|
||||
}
|
||||
|
||||
fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) {
|
||||
let ast::StmtIf {
|
||||
range: _,
|
||||
test,
|
||||
body,
|
||||
elif_else_clauses,
|
||||
} = if_statement;
|
||||
|
||||
self.infer_expression(test);
|
||||
self.infer_body(body);
|
||||
|
||||
for clause in elif_else_clauses {
|
||||
let ast::ElifElseClause {
|
||||
range: _,
|
||||
test,
|
||||
body,
|
||||
} = clause;
|
||||
|
||||
if let Some(test) = &test {
|
||||
self.infer_expression(test);
|
||||
}
|
||||
|
||||
self.infer_body(body);
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_assignment_statement(&mut self, assignment: &ast::StmtAssign) {
|
||||
let ast::StmtAssign {
|
||||
range: _,
|
||||
targets,
|
||||
value,
|
||||
} = assignment;
|
||||
|
||||
let value_ty = self.infer_expression(value);
|
||||
|
||||
for target in targets {
|
||||
self.infer_expression(target);
|
||||
}
|
||||
|
||||
let assign_id = assignment.scope_ast_id(self.db, self.file_id, self.file_scope_id);
|
||||
|
||||
// TODO: Handle multiple targets.
|
||||
self.definition_tys
|
||||
.insert(Definition::Assignment(assign_id), value_ty);
|
||||
}
|
||||
|
||||
fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) {
|
||||
let ast::StmtAnnAssign {
|
||||
range: _,
|
||||
target,
|
||||
annotation,
|
||||
value,
|
||||
simple: _,
|
||||
} = assignment;
|
||||
|
||||
if let Some(value) = value {
|
||||
let _ = self.infer_expression(value);
|
||||
}
|
||||
|
||||
let annotation_ty = self.infer_expression(annotation);
|
||||
self.infer_expression(target);
|
||||
|
||||
self.definition_tys.insert(
|
||||
Definition::AnnotatedAssignment(assignment.scope_ast_id(
|
||||
self.db,
|
||||
self.file_id,
|
||||
self.file_scope_id,
|
||||
)),
|
||||
annotation_ty,
|
||||
);
|
||||
}
|
||||
|
||||
fn infer_for_statement(&mut self, for_statement: &ast::StmtFor) {
|
||||
let ast::StmtFor {
|
||||
range: _,
|
||||
target,
|
||||
iter,
|
||||
body,
|
||||
orelse,
|
||||
is_async: _,
|
||||
} = for_statement;
|
||||
|
||||
self.infer_expression(iter);
|
||||
self.infer_expression(target);
|
||||
self.infer_body(body);
|
||||
self.infer_body(orelse);
|
||||
}
|
||||
|
||||
fn infer_import_statement(&mut self, import: &ast::StmtImport) {
|
||||
let ast::StmtImport { range: _, names } = import;
|
||||
|
||||
let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id);
|
||||
|
||||
for (i, alias) in names.iter().enumerate() {
|
||||
let ast::Alias {
|
||||
range: _,
|
||||
name,
|
||||
asname: _,
|
||||
} = alias;
|
||||
|
||||
let module_name = ModuleName::new(&name.id);
|
||||
let module = module_name.and_then(|name| resolve_module(self.db, name));
|
||||
let module_ty = module
|
||||
.map(|module| self.typing_context().module_ty(module.file()))
|
||||
.unwrap_or(Type::Unknown);
|
||||
|
||||
self.definition_tys.insert(
|
||||
Definition::Import(ImportDefinition {
|
||||
import_id,
|
||||
alias: u32::try_from(i).unwrap(),
|
||||
}),
|
||||
module_ty,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_import_from_statement(&mut self, import: &ast::StmtImportFrom) {
|
||||
let ast::StmtImportFrom {
|
||||
range: _,
|
||||
module,
|
||||
names,
|
||||
level: _,
|
||||
} = import;
|
||||
|
||||
let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id);
|
||||
let module_name = ModuleName::new(module.as_deref().expect("Support relative imports"));
|
||||
|
||||
let module = module_name.and_then(|module_name| resolve_module(self.db, module_name));
|
||||
let module_ty = module
|
||||
.map(|module| self.typing_context().module_ty(module.file()))
|
||||
.unwrap_or(Type::Unknown);
|
||||
|
||||
for (i, alias) in names.iter().enumerate() {
|
||||
let ast::Alias {
|
||||
range: _,
|
||||
name,
|
||||
asname: _,
|
||||
} = alias;
|
||||
|
||||
let ty = module_ty
|
||||
.member(&self.typing_context(), &Name::new(&name.id))
|
||||
.unwrap_or(Type::Unknown);
|
||||
|
||||
self.definition_tys.insert(
|
||||
Definition::ImportFrom(ImportFromDefinition {
|
||||
import_id,
|
||||
name: u32::try_from(i).unwrap(),
|
||||
}),
|
||||
ty,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_decorator(&mut self, decorator: &ast::Decorator) -> Type {
|
||||
let ast::Decorator {
|
||||
range: _,
|
||||
expression,
|
||||
} = decorator;
|
||||
|
||||
self.infer_expression(expression)
|
||||
}
|
||||
|
||||
fn infer_arguments(&mut self, arguments: &ast::Arguments) -> Vec<Type> {
|
||||
let mut types = Vec::with_capacity(
|
||||
arguments
|
||||
.args
|
||||
.len()
|
||||
.saturating_add(arguments.keywords.len()),
|
||||
);
|
||||
|
||||
types.extend(arguments.args.iter().map(|arg| self.infer_expression(arg)));
|
||||
|
||||
types.extend(arguments.keywords.iter().map(
|
||||
|ast::Keyword {
|
||||
range: _,
|
||||
arg: _,
|
||||
value,
|
||||
}| self.infer_expression(value),
|
||||
));
|
||||
|
||||
types
|
||||
}
|
||||
|
||||
fn infer_expression(&mut self, expression: &ast::Expr) -> Type {
|
||||
let ty = match expression {
|
||||
ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _ }) => Type::None,
|
||||
ast::Expr::NumberLiteral(literal) => self.infer_number_literal_expression(literal),
|
||||
ast::Expr::Name(name) => self.infer_name_expression(name),
|
||||
ast::Expr::Attribute(attribute) => self.infer_attribute_expression(attribute),
|
||||
ast::Expr::BinOp(binary) => self.infer_binary_expression(binary),
|
||||
ast::Expr::Named(named) => self.infer_named_expression(named),
|
||||
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression),
|
||||
|
||||
_ => todo!("expression type resolution for {:?}", expression),
|
||||
};
|
||||
|
||||
self.types.expression_tys.push(ty);
|
||||
|
||||
ty
|
||||
}
|
||||
|
||||
#[allow(clippy::unused_self)]
|
||||
fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type {
|
||||
let ast::ExprNumberLiteral { range: _, value } = literal;
|
||||
|
||||
match value {
|
||||
ast::Number::Int(n) => {
|
||||
// TODO support big int literals
|
||||
n.as_i64().map(Type::IntLiteral).unwrap_or(Type::Unknown)
|
||||
}
|
||||
// TODO builtins.float or builtins.complex
|
||||
_ => Type::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type {
|
||||
let ast::ExprNamed {
|
||||
range: _,
|
||||
target,
|
||||
value,
|
||||
} = named;
|
||||
|
||||
let value_ty = self.infer_expression(value);
|
||||
self.infer_expression(target);
|
||||
|
||||
self.definition_tys.insert(
|
||||
Definition::NamedExpr(named.scope_ast_id(self.db, self.file_id, self.file_scope_id)),
|
||||
value_ty,
|
||||
);
|
||||
|
||||
value_ty
|
||||
}
|
||||
|
||||
fn infer_if_expression(&mut self, if_expression: &ast::ExprIf) -> Type {
|
||||
let ast::ExprIf {
|
||||
range: _,
|
||||
test,
|
||||
body,
|
||||
orelse,
|
||||
} = if_expression;
|
||||
|
||||
self.infer_expression(test);
|
||||
|
||||
// TODO detect statically known truthy or falsy test
|
||||
let body_ty = self.infer_expression(body);
|
||||
let orelse_ty = self.infer_expression(orelse);
|
||||
|
||||
let union = UnionTypeBuilder::new(&self.typing_context())
|
||||
.add(body_ty)
|
||||
.add(orelse_ty)
|
||||
.build();
|
||||
|
||||
self.union_ty(union)
|
||||
}
|
||||
|
||||
fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type {
|
||||
let ast::ExprName { range: _, id, ctx } = name;
|
||||
|
||||
match ctx {
|
||||
ExprContext::Load => {
|
||||
if let Some(symbol_id) = self
|
||||
.index
|
||||
.symbol_table(self.file_scope_id)
|
||||
.symbol_id_by_name(id)
|
||||
{
|
||||
self.local_definition_ty(symbol_id)
|
||||
} else {
|
||||
let ancestors = self.index.ancestor_scopes(self.file_scope_id).skip(1);
|
||||
|
||||
for (ancestor_id, _) in ancestors {
|
||||
// TODO: Skip over class scopes unless the they are a immediately-nested type param scope.
|
||||
// TODO: Support built-ins
|
||||
|
||||
let symbol_table =
|
||||
symbol_table(self.db, ancestor_id.to_scope_id(self.db, self.file_id));
|
||||
|
||||
if let Some(_symbol_id) = symbol_table.symbol_id_by_name(id) {
|
||||
todo!("Return type for symbol from outer scope");
|
||||
}
|
||||
}
|
||||
Type::Unknown
|
||||
}
|
||||
}
|
||||
ExprContext::Del => Type::None,
|
||||
ExprContext::Invalid => Type::Unknown,
|
||||
ExprContext::Store => Type::None,
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type {
|
||||
let ast::ExprAttribute {
|
||||
value,
|
||||
attr,
|
||||
range: _,
|
||||
ctx,
|
||||
} = attribute;
|
||||
|
||||
let value_ty = self.infer_expression(value);
|
||||
let member_ty = value_ty
|
||||
.member(&self.typing_context(), &Name::new(&attr.id))
|
||||
.unwrap_or(Type::Unknown);
|
||||
|
||||
match ctx {
|
||||
ExprContext::Load => member_ty,
|
||||
ExprContext::Store | ExprContext::Del => Type::None,
|
||||
ExprContext::Invalid => Type::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_binary_expression(&mut self, binary: &ast::ExprBinOp) -> Type {
|
||||
let ast::ExprBinOp {
|
||||
left,
|
||||
op,
|
||||
right,
|
||||
range: _,
|
||||
} = binary;
|
||||
|
||||
let left_ty = self.infer_expression(left);
|
||||
let right_ty = self.infer_expression(right);
|
||||
|
||||
match left_ty {
|
||||
Type::Any => Type::Any,
|
||||
Type::Unknown => Type::Unknown,
|
||||
Type::IntLiteral(n) => {
|
||||
match right_ty {
|
||||
Type::IntLiteral(m) => {
|
||||
match op {
|
||||
ast::Operator::Add => n
|
||||
.checked_add(m)
|
||||
.map(Type::IntLiteral)
|
||||
// TODO builtins.int
|
||||
.unwrap_or(Type::Unknown),
|
||||
ast::Operator::Sub => n
|
||||
.checked_sub(m)
|
||||
.map(Type::IntLiteral)
|
||||
// TODO builtins.int
|
||||
.unwrap_or(Type::Unknown),
|
||||
ast::Operator::Mult => n
|
||||
.checked_mul(m)
|
||||
.map(Type::IntLiteral)
|
||||
// TODO builtins.int
|
||||
.unwrap_or(Type::Unknown),
|
||||
ast::Operator::Div => n
|
||||
.checked_div(m)
|
||||
.map(Type::IntLiteral)
|
||||
// TODO builtins.int
|
||||
.unwrap_or(Type::Unknown),
|
||||
ast::Operator::Mod => n
|
||||
.checked_rem(m)
|
||||
.map(Type::IntLiteral)
|
||||
// TODO division by zero error
|
||||
.unwrap_or(Type::Unknown),
|
||||
_ => todo!("complete binop op support for IntLiteral"),
|
||||
}
|
||||
}
|
||||
_ => todo!("complete binop right_ty support for IntLiteral"),
|
||||
}
|
||||
}
|
||||
_ => todo!("complete binop support"),
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_type_parameters(&mut self, _type_parameters: &TypeParams) {
|
||||
todo!("Infer type parameters")
|
||||
}
|
||||
|
||||
pub(super) fn finish(mut self) -> TypeInference {
|
||||
let symbol_tys: IndexVec<_, _> = self
|
||||
.index
|
||||
.symbol_table(self.file_scope_id)
|
||||
.symbol_ids()
|
||||
.map(|symbol| self.local_definition_ty(symbol))
|
||||
.collect();
|
||||
|
||||
self.types.symbol_tys = symbol_tys;
|
||||
self.types.shrink_to_fit();
|
||||
self.types
|
||||
}
|
||||
|
||||
fn union_ty(&mut self, ty: UnionType) -> Type {
|
||||
Type::Union(TypeId {
|
||||
scope: self.scope,
|
||||
scoped: self.types.union_types.push(ty),
|
||||
})
|
||||
}
|
||||
|
||||
fn function_ty(&mut self, ty: FunctionType) -> Type {
|
||||
Type::Function(TypeId {
|
||||
scope: self.scope,
|
||||
scoped: self.types.function_types.push(ty),
|
||||
})
|
||||
}
|
||||
|
||||
fn class_ty(&mut self, ty: ClassType) -> Type {
|
||||
Type::Class(TypeId {
|
||||
scope: self.scope,
|
||||
scoped: self.types.class_types.push(ty),
|
||||
})
|
||||
}
|
||||
|
||||
fn typing_context(&self) -> TypingContext {
|
||||
TypingContext::scoped(self.db, self.scope, &self.types)
|
||||
}
|
||||
|
||||
fn local_definition_ty(&mut self, symbol: ScopedSymbolId) -> Type {
|
||||
let symbol = self.symbol_table.symbol(symbol);
|
||||
let mut definitions = symbol
|
||||
.definitions()
|
||||
.iter()
|
||||
.filter_map(|definition| self.definition_tys.get(definition).copied());
|
||||
|
||||
let Some(first) = definitions.next() else {
|
||||
return Type::Unbound;
|
||||
};
|
||||
|
||||
if let Some(second) = definitions.next() {
|
||||
let context = self.typing_context();
|
||||
let mut builder = UnionTypeBuilder::new(&context);
|
||||
builder = builder.add(first).add(second);
|
||||
|
||||
for variant in definitions {
|
||||
builder = builder.add(variant);
|
||||
}
|
||||
|
||||
self.union_ty(builder.build())
|
||||
} else {
|
||||
first
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use ruff_db::file_system::FileSystemPathBuf;
|
||||
use ruff_db::vfs::system_path_to_file;
|
||||
|
||||
use crate::db::tests::TestDb;
|
||||
use crate::module::resolver::{set_module_resolution_settings, ModuleResolutionSettings};
|
||||
use crate::name::Name;
|
||||
use crate::red_knot::types::{public_symbol_ty_by_name, Type, TypingContext};
|
||||
|
||||
fn setup_db() -> TestDb {
|
||||
let mut db = TestDb::new();
|
||||
|
||||
set_module_resolution_settings(
|
||||
&mut db,
|
||||
ModuleResolutionSettings {
|
||||
extra_paths: Vec::new(),
|
||||
workspace_root: FileSystemPathBuf::from("/src"),
|
||||
site_packages: None,
|
||||
custom_typeshed: None,
|
||||
},
|
||||
);
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
fn assert_public_ty(db: &TestDb, file_name: &str, symbol_name: &str, expected: &str) {
|
||||
let file = system_path_to_file(db, file_name).expect("Expected file to exist.");
|
||||
|
||||
let ty = public_symbol_ty_by_name(db, file, symbol_name).unwrap_or(Type::Unknown);
|
||||
assert_eq!(ty.display(&TypingContext::global(db)).to_string(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn follow_import_to_class() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
("src/a.py", "from b import C as D; E = D"),
|
||||
("src/b.py", "class C: pass"),
|
||||
])?;
|
||||
|
||||
assert_public_ty(&db, "src/a.py", "E", "Literal[C]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_base_class_by_name() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system().write_file(
|
||||
"src/mod.py",
|
||||
r#"
|
||||
class Base:
|
||||
pass
|
||||
|
||||
class Sub(Base):
|
||||
pass"#,
|
||||
)?;
|
||||
|
||||
let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist.");
|
||||
let ty = public_symbol_ty_by_name(&db, mod_file, "Sub").expect("Symbol type to exist");
|
||||
|
||||
let Type::Class(class_id) = ty else {
|
||||
panic!("Sub is not a Class")
|
||||
};
|
||||
|
||||
let context = TypingContext::global(&db);
|
||||
|
||||
let base_names: Vec<_> = class_id
|
||||
.lookup(&context)
|
||||
.bases()
|
||||
.iter()
|
||||
.map(|base_ty| format!("{}", base_ty.display(&context)))
|
||||
.collect();
|
||||
|
||||
assert_eq!(base_names, vec!["Literal[Base]"]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_method() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system().write_file(
|
||||
"src/mod.py",
|
||||
"
|
||||
class C:
|
||||
def f(self): pass
|
||||
",
|
||||
)?;
|
||||
|
||||
let mod_file = system_path_to_file(&db, "src/mod.py").unwrap();
|
||||
let ty = public_symbol_ty_by_name(&db, mod_file, "C").unwrap();
|
||||
|
||||
let Type::Class(class_id) = ty else {
|
||||
panic!("C is not a Class");
|
||||
};
|
||||
|
||||
let context = TypingContext::global(&db);
|
||||
let member_ty = class_id.class_member(&context, &Name::new("f"));
|
||||
|
||||
let Some(Type::Function(func_id)) = member_ty else {
|
||||
panic!("C.f is not a Function");
|
||||
};
|
||||
|
||||
let function_ty = func_id.lookup(&context);
|
||||
assert_eq!(function_ty.name(), "f");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_module_member() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system().write_files([
|
||||
("src/a.py", "import b; D = b.C"),
|
||||
("src/b.py", "class C: pass"),
|
||||
])?;
|
||||
|
||||
assert_public_ty(&db, "src/a.py", "D", "Literal[C]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_literal() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system().write_file("src/a.py", "x = 1")?;
|
||||
|
||||
assert_public_ty(&db, "src/a.py", "x", "Literal[1]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_union() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system().write_file(
|
||||
"src/a.py",
|
||||
"
|
||||
if flag:
|
||||
x = 1
|
||||
else:
|
||||
x = 2
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn literal_int_arithmetic() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system().write_file(
|
||||
"src/a.py",
|
||||
"
|
||||
a = 2 + 1
|
||||
b = a - 4
|
||||
c = a * b
|
||||
d = c / 3
|
||||
e = 5 % 3
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_public_ty(&db, "src/a.py", "a", "Literal[3]");
|
||||
assert_public_ty(&db, "src/a.py", "b", "Literal[-1]");
|
||||
assert_public_ty(&db, "src/a.py", "c", "Literal[-3]");
|
||||
assert_public_ty(&db, "src/a.py", "d", "Literal[-1]");
|
||||
assert_public_ty(&db, "src/a.py", "e", "Literal[2]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn walrus() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system()
|
||||
.write_file("src/a.py", "x = (y := 1) + 1")?;
|
||||
|
||||
assert_public_ty(&db, "src/a.py", "x", "Literal[2]");
|
||||
assert_public_ty(&db, "src/a.py", "y", "Literal[1]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ifexpr() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system()
|
||||
.write_file("src/a.py", "x = 1 if flag else 2")?;
|
||||
|
||||
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ifexpr_walrus() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system().write_file(
|
||||
"src/a.py",
|
||||
"
|
||||
y = z = 0
|
||||
x = (y := 1) if flag else (z := 2)
|
||||
a = y
|
||||
b = z
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");
|
||||
assert_public_ty(&db, "src/a.py", "a", "Literal[0, 1]");
|
||||
assert_public_ty(&db, "src/a.py", "b", "Literal[0, 2]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ifexpr_nested() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system()
|
||||
.write_file("src/a.py", "x = 1 if flag else 2 if flag2 else 3")?;
|
||||
|
||||
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2, 3]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn none() -> anyhow::Result<()> {
|
||||
let db = setup_db();
|
||||
|
||||
db.memory_file_system()
|
||||
.write_file("src/a.py", "x = 1 if flag else None")?;
|
||||
|
||||
assert_public_ty(&db, "src/a.py", "x", "Literal[1] | None");
|
||||
Ok(())
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue