red-knot(Salsa): Types without refinements (#11899)

This commit is contained in:
Micha Reiser 2024-06-20 11:49:38 +01:00 committed by GitHub
parent a26bd01be2
commit 22733cb7c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 2169 additions and 147 deletions

View file

@ -20,6 +20,7 @@ ruff_text_size = { workspace = true }
bitflags = { workspace = true }
is-macro = { workspace = true }
indexmap = { workspace = true, optional = true }
salsa = { workspace = true, optional = true }
smallvec = { workspace = true, optional = true }
smol_str = { workspace = true }
@ -36,4 +37,4 @@ tempfile = { workspace = true }
workspace = true
[features]
red_knot = ["dep:salsa", "dep:tracing", "dep:hashbrown", "dep:smallvec"]
red_knot = ["dep:salsa", "dep:tracing", "dep:hashbrown", "dep:smallvec", "dep:indexmap"]

View file

@ -6,20 +6,27 @@ use crate::module::resolver::{
file_to_module, internal::ModuleNameIngredient, internal::ModuleResolverSearchPaths,
resolve_module_query,
};
use crate::red_knot::semantic_index::symbol::ScopeId;
use crate::red_knot::semantic_index::{scopes_map, semantic_index, symbol_table};
use crate::red_knot::semantic_index::symbol::{
public_symbols_map, scopes_map, PublicSymbolId, ScopeId,
};
use crate::red_knot::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::red_knot::types::{infer_types, public_symbol_ty};
#[salsa::jar(db=Db)]
pub struct Jar(
ModuleNameIngredient,
ModuleResolverSearchPaths,
ScopeId,
PublicSymbolId,
symbol_table,
resolve_module_query,
file_to_module,
scopes_map,
root_scope,
semantic_index,
infer_types,
public_symbol_ty,
public_symbols_map,
);
/// Database giving access to semantic information about a Python program.
@ -27,9 +34,13 @@ pub trait Db: SourceDb + DbWithJar<Jar> + Upcast<dyn SourceDb> {}
#[cfg(test)]
pub(crate) mod tests {
use std::fmt::Formatter;
use std::marker::PhantomData;
use std::sync::Arc;
use salsa::DebugWithDb;
use salsa::ingredient::Ingredient;
use salsa::storage::HasIngredientsFor;
use salsa::{AsId, DebugWithDb};
use ruff_db::file_system::{FileSystem, MemoryFileSystem, OsFileSystem};
use ruff_db::vfs::Vfs;
@ -86,7 +97,7 @@ pub(crate) mod tests {
///
/// ## Panics
/// If there are any pending salsa snapshots.
pub(crate) fn take_sale_events(&mut self) -> Vec<salsa::Event> {
pub(crate) fn take_salsa_events(&mut self) -> Vec<salsa::Event> {
let inner = Arc::get_mut(&mut self.events).expect("no pending salsa snapshots");
let events = inner.get_mut().unwrap();
@ -98,7 +109,7 @@ pub(crate) mod tests {
/// ## Panics
/// If there are any pending salsa snapshots.
pub(crate) fn clear_salsa_events(&mut self) {
self.take_sale_events();
self.take_salsa_events();
}
}
@ -150,4 +161,106 @@ pub(crate) mod tests {
#[allow(unused)]
Os(OsFileSystem),
}
pub(crate) fn assert_will_run_function_query<C, Db, Jar>(
db: &Db,
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
key: C::Key,
events: &[salsa::Event],
) where
C: salsa::function::Configuration<Jar = Jar>
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
Jar: HasIngredientsFor<C>,
Db: salsa::DbWithJar<Jar>,
C::Key: AsId,
{
will_run_function_query(db, to_function, key, events, true);
}
pub(crate) fn assert_will_not_run_function_query<C, Db, Jar>(
db: &Db,
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
key: C::Key,
events: &[salsa::Event],
) where
C: salsa::function::Configuration<Jar = Jar>
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
Jar: HasIngredientsFor<C>,
Db: salsa::DbWithJar<Jar>,
C::Key: AsId,
{
will_run_function_query(db, to_function, key, events, false);
}
fn will_run_function_query<C, Db, Jar>(
db: &Db,
to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient<C>,
key: C::Key,
events: &[salsa::Event],
should_run: bool,
) where
C: salsa::function::Configuration<Jar = Jar>
+ salsa::storage::IngredientsFor<Jar = Jar, Ingredients = C>,
Jar: HasIngredientsFor<C>,
Db: salsa::DbWithJar<Jar>,
C::Key: AsId,
{
let (jar, _) =
<_ as salsa::storage::HasJar<<C as salsa::storage::IngredientsFor>::Jar>>::jar(db);
let ingredient = jar.ingredient();
let function_ingredient = to_function(ingredient);
let ingredient_index =
<salsa::function::FunctionIngredient<C> as Ingredient<Db>>::ingredient_index(
function_ingredient,
);
let did_run = events.iter().any(|event| {
if let salsa::EventKind::WillExecute { database_key } = event.kind {
database_key.ingredient_index() == ingredient_index
&& database_key.key_index() == key.as_id()
} else {
false
}
});
if should_run && !did_run {
panic!(
"Expected query {:?} to run but it didn't",
DebugIdx {
db: PhantomData::<Db>,
value_id: key.as_id(),
ingredient: function_ingredient,
}
);
} else if !should_run && did_run {
panic!(
"Expected query {:?} not to run but it did",
DebugIdx {
db: PhantomData::<Db>,
value_id: key.as_id(),
ingredient: function_ingredient,
}
);
}
}
struct DebugIdx<'a, I, Db>
where
I: Ingredient<Db>,
{
value_id: salsa::Id,
ingredient: &'a I,
db: PhantomData<Db>,
}
impl<'a, I, Db> std::fmt::Debug for DebugIdx<'a, I, Db>
where
I: Ingredient<Db>,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.ingredient.fmt_index(Some(self.value_id), f)
}
}
}

View file

@ -886,7 +886,7 @@ mod tests {
let foo_module2 = resolve_module(&db, foo_module_name);
assert!(!db
.take_sale_events()
.take_salsa_events()
.iter()
.any(|event| { matches!(event.kind, salsa::EventKind::WillExecute { .. }) }));

View file

@ -1,3 +1,8 @@
use rustc_hash::FxHasher;
use std::hash::BuildHasherDefault;
pub mod ast_node_ref;
mod node_key;
pub mod semantic_index;
pub mod types;
pub(crate) type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;

View file

@ -9,10 +9,10 @@ use ruff_index::{IndexSlice, IndexVec};
use ruff_python_ast as ast;
use crate::red_knot::node_key::NodeKey;
use crate::red_knot::semantic_index::ast_ids::AstIds;
use crate::red_knot::semantic_index::ast_ids::{AstId, AstIds, ScopeClassId, ScopeFunctionId};
use crate::red_knot::semantic_index::builder::SemanticIndexBuilder;
use crate::red_knot::semantic_index::symbol::{
FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeSymbolId, ScopesMap, SymbolTable,
FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable,
};
use crate::Db;
@ -21,7 +21,7 @@ mod builder;
pub mod definition;
pub mod symbol;
type SymbolMap = hashbrown::HashMap<ScopeSymbolId, (), ()>;
type SymbolMap = hashbrown::HashMap<ScopedSymbolId, (), ()>;
/// Returns the semantic index for `file`.
///
@ -42,33 +42,22 @@ pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex {
pub(crate) fn symbol_table(db: &dyn Db, scope: ScopeId) -> Arc<SymbolTable> {
let index = semantic_index(db, scope.file(db));
index.symbol_table(scope.scope_id(db))
}
/// Returns a mapping from file specific [`FileScopeId`] to a program-wide unique [`ScopeId`].
#[salsa::tracked(return_ref)]
pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap {
let index = semantic_index(db, file);
let scopes: IndexVec<_, _> = index
.scopes
.indices()
.map(|id| ScopeId::new(db, file, id))
.collect();
ScopesMap::new(scopes)
index.symbol_table(scope.file_scope_id(db))
}
/// Returns the root scope of `file`.
pub fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId {
#[salsa::tracked]
pub(crate) fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId {
FileScopeId::root().to_scope_id(db, file)
}
/// Returns the symbol with the given name in `file`'s public scope or `None` if
/// no symbol with the given name exists.
pub fn global_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option<PublicSymbolId> {
pub fn public_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option<PublicSymbolId> {
let root_scope = root_scope(db, file);
root_scope.symbol(db, name)
let symbol_table = symbol_table(db, root_scope);
let local = symbol_table.symbol_id_by_name(name)?;
Some(local.to_public_symbol(db, file))
}
/// The symbol tables for an entire file.
@ -90,6 +79,9 @@ pub struct SemanticIndex {
/// Note: We should not depend on this map when analysing other files or
/// changing a file invalidates all dependents.
ast_ids: IndexVec<FileScopeId, AstIds>,
/// Map from scope to the node that introduces the scope.
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>,
}
impl SemanticIndex {
@ -97,7 +89,7 @@ impl SemanticIndex {
///
/// Use the Salsa cached [`symbol_table`] query if you only need the
/// symbol table for a single scope.
fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
self.symbol_tables[scope_id].clone()
}
@ -152,6 +144,10 @@ impl SemanticIndex {
pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter {
AncestorsIter::new(self, scope)
}
pub(crate) fn scope_node(&self, scope_id: FileScopeId) -> NodeWithScopeId {
self.scope_nodes[scope_id]
}
}
/// ID that uniquely identifies an expression inside a [`Scope`].
@ -246,6 +242,28 @@ impl<'a> Iterator for ChildrenIter<'a> {
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(crate) enum NodeWithScopeId {
Module,
Class(AstId<ScopeClassId>),
ClassTypeParams(AstId<ScopeClassId>),
Function(AstId<ScopeFunctionId>),
FunctionTypeParams(AstId<ScopeFunctionId>),
}
impl NodeWithScopeId {
fn scope_kind(self) -> ScopeKind {
match self {
NodeWithScopeId::Module => ScopeKind::Module,
NodeWithScopeId::Class(_) => ScopeKind::Class,
NodeWithScopeId::Function(_) => ScopeKind::Function,
NodeWithScopeId::ClassTypeParams(_) | NodeWithScopeId::FunctionTypeParams(_) => {
ScopeKind::Annotation
}
}
}
}
impl FusedIterator for ChildrenIter<'_> {}
#[cfg(test)]
@ -583,19 +601,14 @@ class C[T]:
let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4");
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
let parsed = parsed_module(&db, file);
let ast = parsed.syntax();
let x_sym = root_table
.symbol_by_name("x")
.expect("x symbol should exist");
let x_stmt = ast.body[0].as_assign_stmt().unwrap();
let x = &x_stmt.targets[0];
assert_eq!(index.expression_scope(x).kind(), ScopeKind::Module);
assert_eq!(index.expression_scope_id(x), x_sym.scope());
assert_eq!(index.expression_scope_id(x), FileScopeId::root());
let def = ast.body[1].as_function_def_stmt().unwrap();
let y_stmt = def.body[0].as_assign_stmt().unwrap();

View file

@ -66,13 +66,13 @@ impl std::fmt::Debug for AstIds {
}
fn ast_ids(db: &dyn Db, scope: ScopeId) -> &AstIds {
semantic_index(db, scope.file(db)).ast_ids(scope.scope_id(db))
semantic_index(db, scope.file(db)).ast_ids(scope.file_scope_id(db))
}
/// Node that can be uniquely identified by an id in a [`FileScopeId`].
pub trait ScopeAstIdNode {
/// The type of the ID uniquely identifying the node.
type Id;
type Id: Copy;
/// Returns the ID that uniquely identifies the node in `scope`.
///
@ -91,7 +91,7 @@ pub trait ScopeAstIdNode {
/// Extension trait for AST nodes that can be resolved by an `AstId`.
pub trait AstIdNode {
type ScopeId;
type ScopeId: Copy;
/// Resolves the AST id of the node.
///
@ -133,7 +133,7 @@ where
/// Uniquely identifies an AST node in a file.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct AstId<L> {
pub struct AstId<L: Copy> {
/// The node's scope.
scope: FileScopeId,
@ -141,6 +141,16 @@ pub struct AstId<L> {
in_scope_id: L,
}
impl<L: Copy> AstId<L> {
pub(super) fn new(scope: FileScopeId, in_scope_id: L) -> Self {
Self { scope, in_scope_id }
}
pub(super) fn in_scope_id(self) -> L {
self.in_scope_id
}
}
/// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`].
#[newtype_index]
pub struct ScopeExpressionId;

View file

@ -10,16 +10,16 @@ use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor};
use crate::name::Name;
use crate::red_knot::node_key::NodeKey;
use crate::red_knot::semantic_index::ast_ids::{
AstIdsBuilder, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, ScopeImportFromId,
AstId, AstIdsBuilder, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, ScopeImportFromId,
ScopeImportId, ScopeNamedExprId,
};
use crate::red_knot::semantic_index::definition::{
Definition, ImportDefinition, ImportFromDefinition,
};
use crate::red_knot::semantic_index::symbol::{
FileScopeId, FileSymbolId, Scope, ScopeKind, ScopeSymbolId, SymbolFlags, SymbolTableBuilder,
FileScopeId, FileSymbolId, Scope, ScopedSymbolId, SymbolFlags, SymbolTableBuilder,
};
use crate::red_knot::semantic_index::SemanticIndex;
use crate::red_knot::semantic_index::{NodeWithScopeId, SemanticIndex};
pub(super) struct SemanticIndexBuilder<'a> {
// Builder state
@ -33,6 +33,7 @@ pub(super) struct SemanticIndexBuilder<'a> {
symbol_tables: IndexVec<FileScopeId, SymbolTableBuilder>,
ast_ids: IndexVec<FileScopeId, AstIdsBuilder>,
expression_scopes: FxHashMap<NodeKey, FileScopeId>,
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>,
}
impl<'a> SemanticIndexBuilder<'a> {
@ -46,10 +47,11 @@ impl<'a> SemanticIndexBuilder<'a> {
symbol_tables: IndexVec::new(),
ast_ids: IndexVec::new(),
expression_scopes: FxHashMap::default(),
scope_nodes: IndexVec::new(),
};
builder.push_scope_with_parent(
ScopeKind::Module,
NodeWithScopeId::Module,
&Name::new_static("<module>"),
None,
None,
@ -68,18 +70,18 @@ impl<'a> SemanticIndexBuilder<'a> {
fn push_scope(
&mut self,
scope_kind: ScopeKind,
node: NodeWithScopeId,
name: &Name,
defining_symbol: Option<FileSymbolId>,
definition: Option<Definition>,
) {
let parent = self.current_scope();
self.push_scope_with_parent(scope_kind, name, defining_symbol, definition, Some(parent));
self.push_scope_with_parent(node, name, defining_symbol, definition, Some(parent));
}
fn push_scope_with_parent(
&mut self,
scope_kind: ScopeKind,
node: NodeWithScopeId,
name: &Name,
defining_symbol: Option<FileSymbolId>,
definition: Option<Definition>,
@ -92,13 +94,17 @@ impl<'a> SemanticIndexBuilder<'a> {
parent,
defining_symbol,
definition,
kind: scope_kind,
kind: node.scope_kind(),
descendents: children_start..children_start,
};
let scope_id = self.scopes.push(scope);
self.symbol_tables.push(SymbolTableBuilder::new());
self.ast_ids.push(AstIdsBuilder::new());
let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new());
let scope_node_id = self.scope_nodes.push(node);
debug_assert_eq!(ast_id_scope, scope_id);
debug_assert_eq!(scope_id, scope_node_id);
self.scope_stack.push(scope_id);
}
@ -120,11 +126,10 @@ impl<'a> SemanticIndexBuilder<'a> {
&mut self.ast_ids[scope_id]
}
fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopeSymbolId {
let scope = self.current_scope();
fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId {
let symbol_table = self.current_symbol_table();
symbol_table.add_or_update_symbol(name, scope, flags, None)
symbol_table.add_or_update_symbol(name, flags, None)
}
fn add_or_update_symbol_with_definition(
@ -132,27 +137,32 @@ impl<'a> SemanticIndexBuilder<'a> {
name: Name,
definition: Definition,
) -> ScopeSymbolId {
let scope = self.current_scope();
) -> ScopedSymbolId {
let symbol_table = self.current_symbol_table();
symbol_table.add_or_update_symbol(name, scope, SymbolFlags::IS_DEFINED, Some(definition))
symbol_table.add_or_update_symbol(name, SymbolFlags::IS_DEFINED, Some(definition))
}
fn with_type_params(
&mut self,
name: &Name,
params: &Option<Box<ast::TypeParams>>,
definition: Option<Definition>,
with_params: &WithTypeParams,
defining_symbol: FileSymbolId,
nested: impl FnOnce(&mut Self) -> FileScopeId,
) -> FileScopeId {
if let Some(type_params) = params {
let type_params = with_params.type_parameters();
if let Some(type_params) = type_params {
let type_node = match with_params {
WithTypeParams::ClassDef { id, .. } => NodeWithScopeId::ClassTypeParams(*id),
WithTypeParams::FunctionDef { id, .. } => NodeWithScopeId::FunctionTypeParams(*id),
};
self.push_scope(
ScopeKind::Annotation,
type_node,
name,
Some(defining_symbol),
definition,
Some(with_params.definition()),
);
for type_param in &type_params.type_params {
let name = match type_param {
@ -163,9 +173,10 @@ impl<'a> SemanticIndexBuilder<'a> {
self.add_or_update_symbol(Name::new(name), SymbolFlags::IS_DEFINED);
}
}
let nested_scope = nested(self);
if params.is_some() {
if type_params.is_some() {
self.pop_scope();
}
@ -198,10 +209,12 @@ impl<'a> SemanticIndexBuilder<'a> {
ast_ids.shrink_to_fit();
symbol_tables.shrink_to_fit();
self.expression_scopes.shrink_to_fit();
self.scope_nodes.shrink_to_fit();
SemanticIndex {
symbol_tables,
scopes: self.scopes,
scope_nodes: self.scope_nodes,
ast_ids,
expression_scopes: self.expression_scopes,
}
@ -223,7 +236,8 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
self.visit_decorator(decorator);
}
let name = Name::new(&function_def.name.id);
let definition = Definition::FunctionDef(ScopeFunctionId(statement_id));
let function_id = ScopeFunctionId(statement_id);
let definition = Definition::FunctionDef(function_id);
let scope = self.current_scope();
let symbol = FileSymbolId::new(
scope,
@ -232,8 +246,10 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
self.with_type_params(
&name,
&function_def.type_params,
Some(definition),
&WithTypeParams::FunctionDef {
node: function_def,
id: AstId::new(scope, function_id),
},
symbol,
|builder| {
builder.visit_parameters(&function_def.parameters);
@ -242,7 +258,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
}
builder.push_scope(
ScopeKind::Function,
NodeWithScopeId::Function(AstId::new(scope, function_id)),
&name,
Some(symbol),
Some(definition),
@ -258,21 +274,36 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
}
let name = Name::new(&class.name.id);
let definition = Definition::from(ScopeClassId(statement_id));
let class_id = ScopeClassId(statement_id);
let definition = Definition::from(class_id);
let scope = self.current_scope();
let id = FileSymbolId::new(
self.current_scope(),
self.add_or_update_symbol_with_definition(name.clone(), definition),
);
self.with_type_params(&name, &class.type_params, Some(definition), id, |builder| {
if let Some(arguments) = &class.arguments {
builder.visit_arguments(arguments);
}
self.with_type_params(
&name,
&WithTypeParams::ClassDef {
node: class,
id: AstId::new(scope, class_id),
},
id,
|builder| {
if let Some(arguments) = &class.arguments {
builder.visit_arguments(arguments);
}
builder.push_scope(ScopeKind::Class, &name, Some(id), Some(definition));
builder.visit_body(&class.body);
builder.push_scope(
NodeWithScopeId::Class(AstId::new(scope, class_id)),
&name,
Some(id),
Some(definition),
);
builder.visit_body(&class.body);
builder.pop_scope()
});
builder.pop_scope()
},
);
}
ast::Stmt::Import(ast::StmtImport { names, .. }) => {
for (i, alias) in names.iter().enumerate() {
@ -396,3 +427,30 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
}
}
}
enum WithTypeParams<'a> {
ClassDef {
node: &'a ast::StmtClassDef,
id: AstId<ScopeClassId>,
},
FunctionDef {
node: &'a ast::StmtFunctionDef,
id: AstId<ScopeFunctionId>,
},
}
impl<'a> WithTypeParams<'a> {
fn type_parameters(&self) -> Option<&'a ast::TypeParams> {
match self {
WithTypeParams::ClassDef { node, .. } => node.type_params.as_deref(),
WithTypeParams::FunctionDef { node, .. } => node.type_params.as_deref(),
}
}
fn definition(&self) -> Definition {
match self {
WithTypeParams::ClassDef { id, .. } => Definition::ClassDef(id.in_scope_id()),
WithTypeParams::FunctionDef { id, .. } => Definition::FunctionDef(id.in_scope_id()),
}
}
}

View file

@ -3,7 +3,7 @@ use crate::red_knot::semantic_index::ast_ids::{
ScopeImportFromId, ScopeImportId, ScopeNamedExprId,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Definition {
Import(ImportDefinition),
ImportFrom(ImportFromDefinition),
@ -59,18 +59,18 @@ impl From<ScopeNamedExprId> for Definition {
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ImportDefinition {
pub(super) import_id: ScopeImportId,
pub(crate) import_id: ScopeImportId,
/// Index into [`ruff_python_ast::StmtImport::names`].
pub(super) alias: u32,
pub(crate) alias: u32,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ImportFromDefinition {
pub(super) import_id: ScopeImportFromId,
pub(crate) import_id: ScopeImportFromId,
/// Index into [`ruff_python_ast::StmtImportFrom::names`].
pub(super) name: u32,
pub(crate) name: u32,
}

View file

@ -15,24 +15,21 @@ use ruff_index::{newtype_index, IndexVec};
use crate::name::Name;
use crate::red_knot::semantic_index::definition::Definition;
use crate::red_knot::semantic_index::{scopes_map, symbol_table, SymbolMap};
use crate::red_knot::semantic_index::{root_scope, semantic_index, symbol_table, SymbolMap};
use crate::Db;
#[derive(Eq, PartialEq, Debug)]
pub struct Symbol {
name: Name,
flags: SymbolFlags,
scope: FileScopeId,
/// The nodes that define this symbol, in source order.
definitions: SmallVec<[Definition; 4]>,
}
impl Symbol {
fn new(name: Name, scope: FileScopeId, definition: Option<Definition>) -> Self {
fn new(name: Name, definition: Option<Definition>) -> Self {
Self {
name,
scope,
flags: SymbolFlags::empty(),
definitions: definition.into_iter().collect(),
}
@ -51,11 +48,6 @@ impl Symbol {
&self.name
}
/// The scope in which this symbol is defined.
pub fn scope(&self) -> FileScopeId {
self.scope
}
/// Is the symbol used in its containing scope?
pub fn is_used(&self) -> bool {
self.flags.contains(SymbolFlags::IS_USED)
@ -84,62 +76,72 @@ bitflags! {
}
/// ID that uniquely identifies a public symbol defined in a module's root scope.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[salsa::tracked]
pub struct PublicSymbolId {
scope: ScopeId,
symbol: ScopeSymbolId,
}
impl PublicSymbolId {
pub(crate) fn new(scope: ScopeId, symbol: ScopeSymbolId) -> Self {
Self { scope, symbol }
}
pub fn scope(self) -> ScopeId {
self.scope
}
pub(crate) fn scope_symbol(self) -> ScopeSymbolId {
self.symbol
}
}
impl From<PublicSymbolId> for ScopeSymbolId {
fn from(val: PublicSymbolId) -> Self {
val.scope_symbol()
}
#[id]
pub(crate) file: VfsFile,
#[id]
pub(crate) scoped_symbol_id: ScopedSymbolId,
}
/// ID that uniquely identifies a symbol in a file.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FileSymbolId {
scope: FileScopeId,
symbol: ScopeSymbolId,
scoped_symbol_id: ScopedSymbolId,
}
impl FileSymbolId {
pub(super) fn new(scope: FileScopeId, symbol: ScopeSymbolId) -> Self {
Self { scope, symbol }
pub(super) fn new(scope: FileScopeId, symbol: ScopedSymbolId) -> Self {
Self {
scope,
scoped_symbol_id: symbol,
}
}
pub fn scope(self) -> FileScopeId {
self.scope
}
pub(crate) fn symbol(self) -> ScopeSymbolId {
self.symbol
pub(crate) fn scoped_symbol_id(self) -> ScopedSymbolId {
self.scoped_symbol_id
}
}
impl From<FileSymbolId> for ScopeSymbolId {
impl From<FileSymbolId> for ScopedSymbolId {
fn from(val: FileSymbolId) -> Self {
val.symbol()
val.scoped_symbol_id()
}
}
/// Symbol ID that uniquely identifies a symbol inside a [`Scope`].
#[newtype_index]
pub(crate) struct ScopeSymbolId;
pub struct ScopedSymbolId;
impl ScopedSymbolId {
/// Converts the symbol to a public symbol.
///
/// # Panics
/// May panic if the symbol does not belong to `file` or is not a symbol of `file`'s root scope.
pub(crate) fn to_public_symbol(self, db: &dyn Db, file: VfsFile) -> PublicSymbolId {
let symbols = public_symbols_map(db, file);
symbols.public(self)
}
}
/// Returns a mapping from [`FileScopeId`] to globally unique [`ScopeId`].
#[salsa::tracked(return_ref)]
pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap {
let index = semantic_index(db, file);
let scopes: IndexVec<_, _> = index
.scopes
.indices()
.map(|id| ScopeId::new(db, file, id))
.collect();
ScopesMap { scopes }
}
/// Maps from the file specific [`FileScopeId`] to the global [`ScopeId`] that can be used as a Salsa query parameter.
///
@ -152,13 +154,37 @@ pub(crate) struct ScopesMap {
}
impl ScopesMap {
pub(super) fn new(scopes: IndexVec<FileScopeId, ScopeId>) -> Self {
Self { scopes }
}
/// Gets the program-wide unique scope id for the given file specific `scope_id`.
fn get(&self, scope_id: FileScopeId) -> ScopeId {
self.scopes[scope_id]
fn get(&self, scope: FileScopeId) -> ScopeId {
self.scopes[scope]
}
}
#[salsa::tracked(return_ref)]
pub(crate) fn public_symbols_map(db: &dyn Db, file: VfsFile) -> PublicSymbolsMap {
let module_scope = root_scope(db, file);
let symbols = symbol_table(db, module_scope);
let public_symbols: IndexVec<_, _> = symbols
.symbol_ids()
.map(|id| PublicSymbolId::new(db, file, id))
.collect();
PublicSymbolsMap {
symbols: public_symbols,
}
}
/// Maps [`LocalSymbolId`] of a file's root scope to the corresponding [`PublicSymbolId`] (Salsa ingredients).
#[derive(Eq, PartialEq, Debug)]
pub(crate) struct PublicSymbolsMap {
symbols: IndexVec<ScopedSymbolId, PublicSymbolId>,
}
impl PublicSymbolsMap {
/// Resolve the [`PublicSymbolId`] for the module-level `symbol_id`.
fn public(&self, symbol_id: ScopedSymbolId) -> PublicSymbolId {
self.symbols[symbol_id]
}
}
@ -166,18 +192,10 @@ impl ScopesMap {
#[salsa::tracked]
pub struct ScopeId {
#[allow(clippy::used_underscore_binding)]
#[id]
pub file: VfsFile,
pub scope_id: FileScopeId,
}
impl ScopeId {
/// Resolves the symbol named `name` in this scope.
pub fn symbol(self, db: &dyn Db, name: &str) -> Option<PublicSymbolId> {
let symbol_table = symbol_table(db, self);
let in_scope_id = symbol_table.symbol_id_by_name(name)?;
Some(PublicSymbolId::new(self, in_scope_id))
}
#[id]
pub file_scope_id: FileScopeId,
}
/// ID that uniquely identifies a scope inside of a module.
@ -239,7 +257,7 @@ pub enum ScopeKind {
#[derive(Debug)]
pub struct SymbolTable {
/// The symbols in this scope.
symbols: IndexVec<ScopeSymbolId, Symbol>,
symbols: IndexVec<ScopedSymbolId, Symbol>,
/// The symbols indexed by name.
symbols_by_name: SymbolMap,
@ -257,12 +275,12 @@ impl SymbolTable {
self.symbols.shrink_to_fit();
}
pub(crate) fn symbol(&self, symbol_id: impl Into<ScopeSymbolId>) -> &Symbol {
pub(crate) fn symbol(&self, symbol_id: impl Into<ScopedSymbolId>) -> &Symbol {
&self.symbols[symbol_id.into()]
}
#[allow(unused)]
pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopeSymbolId> {
pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopedSymbolId> {
self.symbols.indices()
}
@ -277,8 +295,8 @@ impl SymbolTable {
Some(self.symbol(id))
}
/// Returns the [`ScopeSymbolId`] of the symbol named `name`.
pub(crate) fn symbol_id_by_name(&self, name: &str) -> Option<ScopeSymbolId> {
/// Returns the [`ScopedSymbolId`] of the symbol named `name`.
pub(crate) fn symbol_id_by_name(&self, name: &str) -> Option<ScopedSymbolId> {
let (id, ()) = self
.symbols_by_name
.raw_entry()
@ -320,10 +338,9 @@ impl SymbolTableBuilder {
pub(super) fn add_or_update_symbol(
&mut self,
name: Name,
scope: FileScopeId,
flags: SymbolFlags,
definition: Option<Definition>,
) -> ScopeSymbolId {
) -> ScopedSymbolId {
let hash = SymbolTable::hash_name(&name);
let entry = self
.table
@ -343,7 +360,7 @@ impl SymbolTableBuilder {
*entry.key()
}
RawEntryMut::Vacant(entry) => {
let mut symbol = Symbol::new(name, scope, definition);
let mut symbol = Symbol::new(name, definition);
symbol.insert_flags(flags);
let id = self.table.symbols.push(symbol);

View file

@ -0,0 +1,684 @@
use salsa::DebugWithDb;
use ruff_db::parsed::parsed_module;
use ruff_db::vfs::VfsFile;
use ruff_index::newtype_index;
use ruff_python_ast as ast;
use crate::name::Name;
use crate::red_knot::semantic_index::ast_ids::{AstIdNode, ScopeAstIdNode};
use crate::red_knot::semantic_index::symbol::{FileScopeId, PublicSymbolId, ScopeId};
use crate::red_knot::semantic_index::{
public_symbol, root_scope, semantic_index, symbol_table, NodeWithScopeId,
};
use crate::red_knot::types::infer::{TypeInference, TypeInferenceBuilder};
use crate::red_knot::FxIndexSet;
use crate::Db;
mod display;
mod infer;
/// Infers the type of `expr`.
///
/// Calling this function from a salsa query adds a dependency on [`semantic_index`]
/// which changes with every AST change. That's why you should only call
/// this function for the current file that's being analyzed and not for
/// a dependency (or the query reruns whenever a dependency change).
///
/// Prefer [`public_symbol_ty`] when resolving the type of symbol from another file.
#[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn expression_ty(db: &dyn Db, file: VfsFile, expression: &ast::Expr) -> Type {
let index = semantic_index(db, file);
let file_scope = index.expression_scope_id(expression);
let expression_id = expression.scope_ast_id(db, file, file_scope);
let scope = file_scope.to_scope_id(db, file);
infer_types(db, scope).expression_ty(expression_id)
}
/// Infers the type of a public symbol.
///
/// This is a Salsa query to get symbol-level invalidation instead of file-level dependency invalidation.
/// Without this being a query, changing any public type of a module would invalidate the type inference
/// for the module scope of its dependents and the transitive dependents because.
///
/// For example if we have
/// ```python
/// # a.py
/// import x from b
///
/// # b.py
///
/// x = 20
/// ```
///
/// And x is now changed from `x = 20` to `x = 30`. The following happens:
///
/// * The module level types of `b.py` change because `x` now is a `Literal[30]`.
/// * The module level types of `a.py` change because the imported symbol `x` now has a `Literal[30]` type
/// * The module level types of any dependents of `a.py` change because the imported symbol `x` now has a `Literal[30]` type
/// * And so on for all transitive dependencies.
///
/// This being a query ensures that the invalidation short-circuits if the type of this symbol didn't change.
#[salsa::tracked]
pub(crate) fn public_symbol_ty(db: &dyn Db, symbol: PublicSymbolId) -> Type {
let _ = tracing::debug_span!("public_symbol_ty", "{:?}", symbol.debug(db));
let file = symbol.file(db);
let scope = root_scope(db, file);
let inference = infer_types(db, scope);
inference.symbol_ty(symbol.scoped_symbol_id(db))
}
/// Shorthand for [`public_symbol_ty()`] that takes a symbol name instead of a [`PublicSymbolId`].
pub fn public_symbol_ty_by_name(db: &dyn Db, file: VfsFile, name: &str) -> Option<Type> {
let symbol = public_symbol(db, file, name)?;
Some(public_symbol_ty(db, symbol))
}
/// Infers all types for `scope`.
#[salsa::tracked(return_ref)]
pub(crate) fn infer_types(db: &dyn Db, scope: ScopeId) -> TypeInference {
let file = scope.file(db);
// Using the index here is fine because the code below depends on the AST anyway.
// The isolation of the query is by the return inferred types.
let index = semantic_index(db, file);
let scope_id = scope.file_scope_id(db);
let node = index.scope_node(scope_id);
let mut context = TypeInferenceBuilder::new(db, scope, index);
match node {
NodeWithScopeId::Module => {
let parsed = parsed_module(db.upcast(), file);
context.infer_module(parsed.syntax());
}
NodeWithScopeId::Class(class_id) => {
let class = ast::StmtClassDef::lookup(db, file, class_id);
context.infer_class_body(class);
}
NodeWithScopeId::ClassTypeParams(class_id) => {
let class = ast::StmtClassDef::lookup(db, file, class_id);
context.infer_class_type_params(class);
}
NodeWithScopeId::Function(function_id) => {
let function = ast::StmtFunctionDef::lookup(db, file, function_id);
context.infer_function_body(function);
}
NodeWithScopeId::FunctionTypeParams(function_id) => {
let function = ast::StmtFunctionDef::lookup(db, file, function_id);
context.infer_function_type_params(function);
}
}
context.finish()
}
/// unique ID for a type
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Type {
/// the dynamic type: a statically-unknown set of values
Any,
/// the empty set of values
Never,
/// unknown type (no annotation)
/// equivalent to Any, or to object in strict mode
Unknown,
/// name is not bound to any value
Unbound,
/// the None object (TODO remove this in favor of Instance(types.NoneType)
None,
/// a specific function object
Function(TypeId<ScopedFunctionTypeId>),
/// a specific module object
Module(TypeId<ScopedModuleTypeId>),
/// a specific class object
Class(TypeId<ScopedClassTypeId>),
/// the set of Python objects with the given class in their __class__'s method resolution order
Instance(TypeId<ScopedClassTypeId>),
Union(TypeId<ScopedUnionTypeId>),
Intersection(TypeId<ScopedIntersectionTypeId>),
IntLiteral(i64),
// TODO protocols, callable types, overloads, generics, type vars
}
impl Type {
pub const fn is_unbound(&self) -> bool {
matches!(self, Type::Unbound)
}
pub const fn is_unknown(&self) -> bool {
matches!(self, Type::Unknown)
}
pub fn member(&self, context: &TypingContext, name: &Name) -> Option<Type> {
match self {
Type::Any => Some(Type::Any),
Type::Never => todo!("attribute lookup on Never type"),
Type::Unknown => Some(Type::Unknown),
Type::Unbound => todo!("attribute lookup on Unbound type"),
Type::None => todo!("attribute lookup on None type"),
Type::Function(_) => todo!("attribute lookup on Function type"),
Type::Module(module) => module.member(context, name),
Type::Class(class) => class.class_member(context, name),
Type::Instance(_) => {
// TODO MRO? get_own_instance_member, get_instance_member
todo!("attribute lookup on Instance type")
}
Type::Union(union_id) => {
let _union = union_id.lookup(context);
// TODO perform the get_member on each type in the union
// TODO return the union of those results
// TODO if any of those results is `None` then include Unknown in the result union
todo!("attribute lookup on Union type")
}
Type::Intersection(_) => {
// TODO perform the get_member on each type in the intersection
// TODO return the intersection of those results
todo!("attribute lookup on Intersection type")
}
Type::IntLiteral(_) => {
// TODO raise error
Some(Type::Unknown)
}
}
}
}
/// ID that uniquely identifies a type in a program.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TypeId<L> {
/// The scope in which this type is defined or was created.
scope: ScopeId,
/// The type's local ID in its scope.
scoped: L,
}
impl<Id> TypeId<Id>
where
Id: Copy,
{
pub fn scope(&self) -> ScopeId {
self.scope
}
pub fn scoped_id(&self) -> Id {
self.scoped
}
/// Resolves the type ID to the actual type.
pub(crate) fn lookup<'a>(self, context: &'a TypingContext) -> &'a Id::Ty
where
Id: ScopedTypeId,
{
let types = context.types(self.scope);
self.scoped.lookup_scoped(types)
}
}
/// ID that uniquely identifies a type in a scope.
pub(crate) trait ScopedTypeId {
/// The type that this ID points to.
type Ty;
/// Looks up the type in `index`.
///
/// ## Panics
/// May panic if this type is from another scope than `index`, or might just return an invalid type.
fn lookup_scoped(self, index: &TypeInference) -> &Self::Ty;
}
/// ID uniquely identifying a function type in a `scope`.
#[newtype_index]
pub struct ScopedFunctionTypeId;
impl ScopedTypeId for ScopedFunctionTypeId {
type Ty = FunctionType;
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
types.function_ty(self)
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct FunctionType {
/// name of the function at definition
name: Name,
/// types of all decorators on this function
decorators: Vec<Type>,
}
impl FunctionType {
fn name(&self) -> &str {
self.name.as_str()
}
#[allow(unused)]
pub(crate) fn decorators(&self) -> &[Type] {
self.decorators.as_slice()
}
}
#[newtype_index]
pub struct ScopedClassTypeId;
impl ScopedTypeId for ScopedClassTypeId {
type Ty = ClassType;
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
types.class_ty(self)
}
}
impl TypeId<ScopedClassTypeId> {
/// Returns the class member of this class named `name`.
///
/// The member resolves to a member of the class itself or any of its bases.
fn class_member(self, context: &TypingContext, name: &Name) -> Option<Type> {
if let Some(member) = self.own_class_member(context, name) {
return Some(member);
}
let class = self.lookup(context);
for base in &class.bases {
if let Some(member) = base.member(context, name) {
return Some(member);
}
}
None
}
/// Returns the inferred type of the class member named `name`.
fn own_class_member(self, context: &TypingContext, name: &Name) -> Option<Type> {
let class = self.lookup(context);
let symbols = symbol_table(context.db, class.body_scope);
let symbol = symbols.symbol_id_by_name(name)?;
let types = context.types(class.body_scope);
Some(types.symbol_ty(symbol))
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct ClassType {
/// Name of the class at definition
name: Name,
/// Types of all class bases
bases: Vec<Type>,
body_scope: ScopeId,
}
impl ClassType {
fn name(&self) -> &str {
self.name.as_str()
}
#[allow(unused)]
pub(super) fn bases(&self) -> &[Type] {
self.bases.as_slice()
}
}
#[newtype_index]
pub struct ScopedUnionTypeId;
impl ScopedTypeId for ScopedUnionTypeId {
type Ty = UnionType;
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
types.union_ty(self)
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct UnionType {
// the union type includes values in any of these types
elements: FxIndexSet<Type>,
}
struct UnionTypeBuilder<'a> {
elements: FxIndexSet<Type>,
context: &'a TypingContext<'a>,
}
impl<'a> UnionTypeBuilder<'a> {
fn new(context: &'a TypingContext<'a>) -> Self {
Self {
context,
elements: FxIndexSet::default(),
}
}
/// Adds a type to this union.
fn add(mut self, ty: Type) -> Self {
match ty {
Type::Union(union_id) => {
let union = union_id.lookup(self.context);
self.elements.extend(&union.elements);
}
_ => {
self.elements.insert(ty);
}
}
self
}
fn build(self) -> UnionType {
UnionType {
elements: self.elements,
}
}
}
#[newtype_index]
pub struct ScopedIntersectionTypeId;
impl ScopedTypeId for ScopedIntersectionTypeId {
type Ty = IntersectionType;
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
types.intersection_ty(self)
}
}
// Negation types aren't expressible in annotations, and are most likely to arise from type
// narrowing along with intersections (e.g. `if not isinstance(...)`), so we represent them
// directly in intersections rather than as a separate type. This sacrifices some efficiency in the
// case where a Not appears outside an intersection (unclear when that could even happen, but we'd
// have to represent it as a single-element intersection if it did) in exchange for better
// efficiency in the within-intersection case.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct IntersectionType {
// the intersection type includes only values in all of these types
positive: FxIndexSet<Type>,
// the intersection type does not include any value in any of these types
negative: FxIndexSet<Type>,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ScopedModuleTypeId;
impl ScopedTypeId for ScopedModuleTypeId {
type Ty = ModuleType;
fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty {
types.module_ty()
}
}
impl TypeId<ScopedModuleTypeId> {
fn member(self, context: &TypingContext, name: &Name) -> Option<Type> {
context.public_symbol_ty(self.scope.file(context.db), name)
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct ModuleType {
file: VfsFile,
}
/// Context in which to resolve types.
///
/// This abstraction is necessary to support a uniform API that can be used
/// while in the process of building the type inference structure for a scope
/// but also when all types should be resolved by querying the db.
pub struct TypingContext<'a> {
db: &'a dyn Db,
/// The Local type inference scope that is in the process of being built.
///
/// Bypass the `db` when resolving the types for this scope.
local: Option<(ScopeId, &'a TypeInference)>,
}
impl<'a> TypingContext<'a> {
/// Creates a context that resolves all types by querying the db.
#[allow(unused)]
pub(super) fn global(db: &'a dyn Db) -> Self {
Self { db, local: None }
}
/// Creates a context that by-passes the `db` when resolving types from `scope_id` and instead uses `types`.
fn scoped(db: &'a dyn Db, scope_id: ScopeId, types: &'a TypeInference) -> Self {
Self {
db,
local: Some((scope_id, types)),
}
}
/// Returns the [`TypeInference`] results (not guaranteed to be complete) for `scope_id`.
fn types(&self, scope_id: ScopeId) -> &'a TypeInference {
if let Some((scope, local_types)) = self.local {
if scope == scope_id {
return local_types;
}
}
infer_types(self.db, scope_id)
}
fn module_ty(&self, file: VfsFile) -> Type {
let scope = root_scope(self.db, file);
Type::Module(TypeId {
scope,
scoped: ScopedModuleTypeId,
})
}
/// Resolves the public type of a symbol named `name` defined in `file`.
///
/// This function calls [`public_symbol_ty`] if the local scope isn't the module scope of `file`.
/// It otherwise tries to resolve the symbol type locally.
fn public_symbol_ty(&self, file: VfsFile, name: &Name) -> Option<Type> {
let symbol = public_symbol(self.db, file, name)?;
if let Some((scope, local_types)) = self.local {
if scope.file_scope_id(self.db) == FileScopeId::root() && scope.file(self.db) == file {
return Some(local_types.symbol_ty(symbol.scoped_symbol_id(self.db)));
}
}
Some(public_symbol_ty(self.db, symbol))
}
}
#[cfg(test)]
mod tests {
use ruff_db::file_system::FileSystemPathBuf;
use ruff_db::parsed::parsed_module;
use ruff_db::vfs::system_path_to_file;
use crate::db::tests::{
assert_will_not_run_function_query, assert_will_run_function_query, TestDb,
};
use crate::module::resolver::{set_module_resolution_settings, ModuleResolutionSettings};
use crate::red_knot::semantic_index::root_scope;
use crate::red_knot::types::{
expression_ty, infer_types, public_symbol_ty_by_name, TypingContext,
};
fn setup_db() -> TestDb {
let mut db = TestDb::new();
set_module_resolution_settings(
&mut db,
ModuleResolutionSettings {
extra_paths: vec![],
workspace_root: FileSystemPathBuf::from("/src"),
site_packages: None,
custom_typeshed: None,
},
);
db
}
#[test]
fn local_inference() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file("/src/a.py", "x = 10")?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let parsed = parsed_module(&db, a);
let statement = parsed.suite().first().unwrap().as_assign_stmt().unwrap();
let literal_ty = expression_ty(&db, a, &statement.value);
assert_eq!(
format!("{}", literal_ty.display(&TypingContext::global(&db))),
"Literal[10]"
);
Ok(())
}
#[test]
fn dependency_public_symbol_type_change() -> anyhow::Result<()> {
let mut db = setup_db();
db.memory_file_system().write_files([
("/src/a.py", "from foo import x"),
("/src/foo.py", "x = 10\ndef foo(): ..."),
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
// Change `x` to a different value
db.memory_file_system()
.write_file("/src/foo.py", "x = 20\ndef foo(): ...")?;
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
foo.touch(&mut db);
let a = system_path_to_file(&db, "/src/a.py").unwrap();
db.clear_salsa_events();
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[20]"
);
let events = db.take_salsa_events();
let a_root_scope = root_scope(&db, a);
assert_will_run_function_query::<infer_types, _, _>(
&db,
|ty| &ty.function,
a_root_scope,
&events,
);
Ok(())
}
#[test]
fn dependency_non_public_symbol_change() -> anyhow::Result<()> {
let mut db = setup_db();
db.memory_file_system().write_files([
("/src/a.py", "from foo import x"),
("/src/foo.py", "x = 10\ndef foo(): y = 1"),
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
db.memory_file_system()
.write_file("/src/foo.py", "x = 10\ndef foo(): pass")?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
foo.touch(&mut db);
db.clear_salsa_events();
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
let events = db.take_salsa_events();
let a_root_scope = root_scope(&db, a);
assert_will_not_run_function_query::<infer_types, _, _>(
&db,
|ty| &ty.function,
a_root_scope,
&events,
);
Ok(())
}
#[test]
fn dependency_unrelated_public_symbol() -> anyhow::Result<()> {
let mut db = setup_db();
db.memory_file_system().write_files([
("/src/a.py", "from foo import x"),
("/src/foo.py", "x = 10\ny = 20"),
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
db.memory_file_system()
.write_file("/src/foo.py", "x = 10\ny = 30")?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
foo.touch(&mut db);
db.clear_salsa_events();
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
let events = db.take_salsa_events();
let a_root_scope = root_scope(&db, a);
assert_will_not_run_function_query::<infer_types, _, _>(
&db,
|ty| &ty.function,
a_root_scope,
&events,
);
Ok(())
}
}

View file

@ -0,0 +1,175 @@
//! Display implementations for types.
use std::fmt::{Display, Formatter};
use crate::red_knot::types::{IntersectionType, Type, TypingContext, UnionType};
impl Type {
pub fn display<'a>(&'a self, context: &'a TypingContext) -> DisplayType<'a> {
DisplayType { ty: self, context }
}
}
#[derive(Copy, Clone)]
pub struct DisplayType<'a> {
ty: &'a Type,
context: &'a TypingContext<'a>,
}
impl Display for DisplayType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self.ty {
Type::Any => f.write_str("Any"),
Type::Never => f.write_str("Never"),
Type::Unknown => f.write_str("Unknown"),
Type::Unbound => f.write_str("Unbound"),
Type::None => f.write_str("None"),
Type::Module(module_id) => {
write!(
f,
"<module '{:?}'>",
module_id
.scope
.file(self.context.db)
.path(self.context.db.upcast())
)
}
// TODO functions and classes should display using a fully qualified name
Type::Class(class_id) => {
let class = class_id.lookup(self.context);
f.write_str("Literal[")?;
f.write_str(class.name())?;
f.write_str("]")
}
Type::Instance(class_id) => {
let class = class_id.lookup(self.context);
f.write_str(class.name())
}
Type::Function(function_id) => {
let function = function_id.lookup(self.context);
f.write_str(function.name())
}
Type::Union(union_id) => {
let union = union_id.lookup(self.context);
union.display(self.context).fmt(f)
}
Type::Intersection(intersection_id) => {
let intersection = intersection_id.lookup(self.context);
intersection.display(self.context).fmt(f)
}
Type::IntLiteral(n) => write!(f, "Literal[{n}]"),
}
}
}
impl std::fmt::Debug for DisplayType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}
impl UnionType {
fn display<'a>(&'a self, context: &'a TypingContext<'a>) -> DisplayUnionType<'a> {
DisplayUnionType { context, ty: self }
}
}
struct DisplayUnionType<'a> {
ty: &'a UnionType,
context: &'a TypingContext<'a>,
}
impl Display for DisplayUnionType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let union = self.ty;
let (int_literals, other_types): (Vec<Type>, Vec<Type>) = union
.elements
.iter()
.copied()
.partition(|ty| matches!(ty, Type::IntLiteral(_)));
let mut first = true;
if !int_literals.is_empty() {
f.write_str("Literal[")?;
let mut nums: Vec<_> = int_literals
.into_iter()
.filter_map(|ty| {
if let Type::IntLiteral(n) = ty {
Some(n)
} else {
None
}
})
.collect();
nums.sort_unstable();
for num in nums {
if !first {
f.write_str(", ")?;
}
write!(f, "{num}")?;
first = false;
}
f.write_str("]")?;
}
for ty in other_types {
if !first {
f.write_str(" | ")?;
};
first = false;
write!(f, "{}", ty.display(self.context))?;
}
Ok(())
}
}
impl std::fmt::Debug for DisplayUnionType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}
impl IntersectionType {
fn display<'a>(&'a self, context: &'a TypingContext<'a>) -> DisplayIntersectionType<'a> {
DisplayIntersectionType { ty: self, context }
}
}
struct DisplayIntersectionType<'a> {
ty: &'a IntersectionType,
context: &'a TypingContext<'a>,
}
impl Display for DisplayIntersectionType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let mut first = true;
for (neg, ty) in self
.ty
.positive
.iter()
.map(|ty| (false, ty))
.chain(self.ty.negative.iter().map(|ty| (true, ty)))
{
if !first {
f.write_str(" & ")?;
};
first = false;
if neg {
f.write_str("~")?;
};
write!(f, "{}", ty.display(self.context))?;
}
Ok(())
}
}
impl std::fmt::Debug for DisplayIntersectionType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}

View file

@ -0,0 +1,945 @@
use std::sync::Arc;
use rustc_hash::FxHashMap;
use ruff_db::vfs::VfsFile;
use ruff_index::IndexVec;
use ruff_python_ast as ast;
use ruff_python_ast::{ExprContext, TypeParams};
use crate::module::resolver::resolve_module;
use crate::module::ModuleName;
use crate::name::Name;
use crate::red_knot::semantic_index::ast_ids::{ScopeAstIdNode, ScopeExpressionId};
use crate::red_knot::semantic_index::definition::{
Definition, ImportDefinition, ImportFromDefinition,
};
use crate::red_knot::semantic_index::symbol::{
FileScopeId, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable,
};
use crate::red_knot::semantic_index::{symbol_table, ChildrenIter, SemanticIndex};
use crate::red_knot::types::{
ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId, ScopedFunctionTypeId,
ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext, UnionType,
UnionTypeBuilder,
};
use crate::Db;
/// The inferred types for a single scope.
#[derive(Debug, Eq, PartialEq, Default, Clone)]
pub(crate) struct TypeInference {
/// The type of the module if the scope is a module scope.
module_type: Option<ModuleType>,
/// The types of the defined classes in this scope.
class_types: IndexVec<ScopedClassTypeId, ClassType>,
/// The types of the defined functions in this scope.
function_types: IndexVec<ScopedFunctionTypeId, FunctionType>,
union_types: IndexVec<ScopedUnionTypeId, UnionType>,
intersection_types: IndexVec<ScopedIntersectionTypeId, IntersectionType>,
/// The types of every expression in this scope.
expression_tys: IndexVec<ScopeExpressionId, Type>,
/// The public types of every symbol in this scope.
symbol_tys: IndexVec<ScopedSymbolId, Type>,
}
impl TypeInference {
#[allow(unused)]
pub(super) fn expression_ty(&self, expression: ScopeExpressionId) -> Type {
self.expression_tys[expression]
}
pub(super) fn symbol_ty(&self, symbol: ScopedSymbolId) -> Type {
self.symbol_tys[symbol]
}
pub(super) fn module_ty(&self) -> &ModuleType {
self.module_type.as_ref().unwrap()
}
pub(super) fn class_ty(&self, id: ScopedClassTypeId) -> &ClassType {
&self.class_types[id]
}
pub(super) fn function_ty(&self, id: ScopedFunctionTypeId) -> &FunctionType {
&self.function_types[id]
}
pub(super) fn union_ty(&self, id: ScopedUnionTypeId) -> &UnionType {
&self.union_types[id]
}
pub(super) fn intersection_ty(&self, id: ScopedIntersectionTypeId) -> &IntersectionType {
&self.intersection_types[id]
}
fn shrink_to_fit(&mut self) {
self.class_types.shrink_to_fit();
self.function_types.shrink_to_fit();
self.union_types.shrink_to_fit();
self.intersection_types.shrink_to_fit();
self.expression_tys.shrink_to_fit();
self.symbol_tys.shrink_to_fit();
}
}
/// Builder to infer all types in a [`ScopeId`].
pub(super) struct TypeInferenceBuilder<'a> {
db: &'a dyn Db,
// Cached lookups
index: &'a SemanticIndex,
scope: ScopeId,
file_scope_id: FileScopeId,
file_id: VfsFile,
symbol_table: Arc<SymbolTable>,
/// The type inference results
types: TypeInference,
definition_tys: FxHashMap<Definition, Type>,
children_scopes: ChildrenIter<'a>,
}
impl<'a> TypeInferenceBuilder<'a> {
/// Creates a new builder for inferring the types of `scope`.
pub(super) fn new(db: &'a dyn Db, scope: ScopeId, index: &'a SemanticIndex) -> Self {
let file_scope_id = scope.file_scope_id(db);
let file = scope.file(db);
let children_scopes = index.child_scopes(file_scope_id);
let symbol_table = index.symbol_table(file_scope_id);
Self {
index,
file_scope_id,
file_id: file,
scope,
symbol_table,
db,
types: TypeInference::default(),
definition_tys: FxHashMap::default(),
children_scopes,
}
}
/// Infers the types of a `module`.
pub(super) fn infer_module(&mut self, module: &ast::ModModule) {
self.infer_body(&module.body);
}
pub(super) fn infer_class_type_params(&mut self, class: &ast::StmtClassDef) {
if let Some(type_params) = class.type_params.as_deref() {
self.infer_type_parameters(type_params);
}
}
pub(super) fn infer_class_body(&mut self, class: &ast::StmtClassDef) {
self.infer_body(&class.body);
}
pub(super) fn infer_function_type_params(&mut self, function: &ast::StmtFunctionDef) {
if let Some(type_params) = function.type_params.as_deref() {
self.infer_type_parameters(type_params);
}
}
pub(super) fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) {
self.infer_body(&function.body);
}
fn infer_body(&mut self, suite: &[ast::Stmt]) {
for statement in suite {
self.infer_statement(statement);
}
}
fn infer_statement(&mut self, statement: &ast::Stmt) {
match statement {
ast::Stmt::FunctionDef(function) => self.infer_function_definition_statement(function),
ast::Stmt::ClassDef(class) => self.infer_class_definition_statement(class),
ast::Stmt::Expr(ast::StmtExpr { range: _, value }) => {
self.infer_expression(value);
}
ast::Stmt::If(if_statement) => self.infer_if_statement(if_statement),
ast::Stmt::Assign(assign) => self.infer_assignment_statement(assign),
ast::Stmt::AnnAssign(assign) => self.infer_annotated_assignment_statement(assign),
ast::Stmt::For(for_statement) => self.infer_for_statement(for_statement),
ast::Stmt::Import(import) => self.infer_import_statement(import),
ast::Stmt::ImportFrom(import) => self.infer_import_from_statement(import),
ast::Stmt::Break(_) | ast::Stmt::Continue(_) | ast::Stmt::Pass(_) => {
// No-op
}
_ => {}
}
}
fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) {
let ast::StmtFunctionDef {
range: _,
is_async: _,
name,
type_params: _,
parameters: _,
returns,
body: _,
decorator_list,
} = function;
let function_id = function.scope_ast_id(self.db, self.file_id, self.file_scope_id);
let decorator_tys = decorator_list
.iter()
.map(|decorator| self.infer_decorator(decorator))
.collect();
// TODO: Infer parameters
if let Some(return_ty) = returns {
self.infer_expression(return_ty);
}
let function_ty = self.function_ty(FunctionType {
name: Name::new(&name.id),
decorators: decorator_tys,
});
// Skip over the function or type params child scope.
let (_, scope) = self.children_scopes.next().unwrap();
assert!(matches!(
scope.kind(),
ScopeKind::Function | ScopeKind::Annotation
));
self.definition_tys
.insert(Definition::FunctionDef(function_id), function_ty);
}
fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) {
let ast::StmtClassDef {
range: _,
name,
type_params,
decorator_list,
arguments,
body: _,
} = class;
let class_id = class.scope_ast_id(self.db, self.file_id, self.file_scope_id);
for decorator in decorator_list {
self.infer_decorator(decorator);
}
let bases = arguments
.as_deref()
.map(|arguments| self.infer_arguments(arguments))
.unwrap_or(Vec::new());
// If the class has type parameters, then the class body scope is the first child scope of the type parameter's scope
// Otherwise the next scope must be the class definition scope.
let (class_body_scope_id, class_body_scope) = if type_params.is_some() {
let (type_params_scope, _) = self.children_scopes.next().unwrap();
self.index.child_scopes(type_params_scope).next().unwrap()
} else {
self.children_scopes.next().unwrap()
};
assert_eq!(class_body_scope.kind(), ScopeKind::Class);
let class_ty = self.class_ty(ClassType {
name: Name::new(name),
bases,
body_scope: class_body_scope_id.to_scope_id(self.db, self.file_id),
});
self.definition_tys
.insert(Definition::ClassDef(class_id), class_ty);
}
fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) {
let ast::StmtIf {
range: _,
test,
body,
elif_else_clauses,
} = if_statement;
self.infer_expression(test);
self.infer_body(body);
for clause in elif_else_clauses {
let ast::ElifElseClause {
range: _,
test,
body,
} = clause;
if let Some(test) = &test {
self.infer_expression(test);
}
self.infer_body(body);
}
}
fn infer_assignment_statement(&mut self, assignment: &ast::StmtAssign) {
let ast::StmtAssign {
range: _,
targets,
value,
} = assignment;
let value_ty = self.infer_expression(value);
for target in targets {
self.infer_expression(target);
}
let assign_id = assignment.scope_ast_id(self.db, self.file_id, self.file_scope_id);
// TODO: Handle multiple targets.
self.definition_tys
.insert(Definition::Assignment(assign_id), value_ty);
}
fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) {
let ast::StmtAnnAssign {
range: _,
target,
annotation,
value,
simple: _,
} = assignment;
if let Some(value) = value {
let _ = self.infer_expression(value);
}
let annotation_ty = self.infer_expression(annotation);
self.infer_expression(target);
self.definition_tys.insert(
Definition::AnnotatedAssignment(assignment.scope_ast_id(
self.db,
self.file_id,
self.file_scope_id,
)),
annotation_ty,
);
}
fn infer_for_statement(&mut self, for_statement: &ast::StmtFor) {
let ast::StmtFor {
range: _,
target,
iter,
body,
orelse,
is_async: _,
} = for_statement;
self.infer_expression(iter);
self.infer_expression(target);
self.infer_body(body);
self.infer_body(orelse);
}
fn infer_import_statement(&mut self, import: &ast::StmtImport) {
let ast::StmtImport { range: _, names } = import;
let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id);
for (i, alias) in names.iter().enumerate() {
let ast::Alias {
range: _,
name,
asname: _,
} = alias;
let module_name = ModuleName::new(&name.id);
let module = module_name.and_then(|name| resolve_module(self.db, name));
let module_ty = module
.map(|module| self.typing_context().module_ty(module.file()))
.unwrap_or(Type::Unknown);
self.definition_tys.insert(
Definition::Import(ImportDefinition {
import_id,
alias: u32::try_from(i).unwrap(),
}),
module_ty,
);
}
}
fn infer_import_from_statement(&mut self, import: &ast::StmtImportFrom) {
let ast::StmtImportFrom {
range: _,
module,
names,
level: _,
} = import;
let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id);
let module_name = ModuleName::new(module.as_deref().expect("Support relative imports"));
let module = module_name.and_then(|module_name| resolve_module(self.db, module_name));
let module_ty = module
.map(|module| self.typing_context().module_ty(module.file()))
.unwrap_or(Type::Unknown);
for (i, alias) in names.iter().enumerate() {
let ast::Alias {
range: _,
name,
asname: _,
} = alias;
let ty = module_ty
.member(&self.typing_context(), &Name::new(&name.id))
.unwrap_or(Type::Unknown);
self.definition_tys.insert(
Definition::ImportFrom(ImportFromDefinition {
import_id,
name: u32::try_from(i).unwrap(),
}),
ty,
);
}
}
fn infer_decorator(&mut self, decorator: &ast::Decorator) -> Type {
let ast::Decorator {
range: _,
expression,
} = decorator;
self.infer_expression(expression)
}
fn infer_arguments(&mut self, arguments: &ast::Arguments) -> Vec<Type> {
let mut types = Vec::with_capacity(
arguments
.args
.len()
.saturating_add(arguments.keywords.len()),
);
types.extend(arguments.args.iter().map(|arg| self.infer_expression(arg)));
types.extend(arguments.keywords.iter().map(
|ast::Keyword {
range: _,
arg: _,
value,
}| self.infer_expression(value),
));
types
}
fn infer_expression(&mut self, expression: &ast::Expr) -> Type {
let ty = match expression {
ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _ }) => Type::None,
ast::Expr::NumberLiteral(literal) => self.infer_number_literal_expression(literal),
ast::Expr::Name(name) => self.infer_name_expression(name),
ast::Expr::Attribute(attribute) => self.infer_attribute_expression(attribute),
ast::Expr::BinOp(binary) => self.infer_binary_expression(binary),
ast::Expr::Named(named) => self.infer_named_expression(named),
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression),
_ => todo!("expression type resolution for {:?}", expression),
};
self.types.expression_tys.push(ty);
ty
}
#[allow(clippy::unused_self)]
fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type {
let ast::ExprNumberLiteral { range: _, value } = literal;
match value {
ast::Number::Int(n) => {
// TODO support big int literals
n.as_i64().map(Type::IntLiteral).unwrap_or(Type::Unknown)
}
// TODO builtins.float or builtins.complex
_ => Type::Unknown,
}
}
fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type {
let ast::ExprNamed {
range: _,
target,
value,
} = named;
let value_ty = self.infer_expression(value);
self.infer_expression(target);
self.definition_tys.insert(
Definition::NamedExpr(named.scope_ast_id(self.db, self.file_id, self.file_scope_id)),
value_ty,
);
value_ty
}
fn infer_if_expression(&mut self, if_expression: &ast::ExprIf) -> Type {
let ast::ExprIf {
range: _,
test,
body,
orelse,
} = if_expression;
self.infer_expression(test);
// TODO detect statically known truthy or falsy test
let body_ty = self.infer_expression(body);
let orelse_ty = self.infer_expression(orelse);
let union = UnionTypeBuilder::new(&self.typing_context())
.add(body_ty)
.add(orelse_ty)
.build();
self.union_ty(union)
}
fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type {
let ast::ExprName { range: _, id, ctx } = name;
match ctx {
ExprContext::Load => {
if let Some(symbol_id) = self
.index
.symbol_table(self.file_scope_id)
.symbol_id_by_name(id)
{
self.local_definition_ty(symbol_id)
} else {
let ancestors = self.index.ancestor_scopes(self.file_scope_id).skip(1);
for (ancestor_id, _) in ancestors {
// TODO: Skip over class scopes unless the they are a immediately-nested type param scope.
// TODO: Support built-ins
let symbol_table =
symbol_table(self.db, ancestor_id.to_scope_id(self.db, self.file_id));
if let Some(_symbol_id) = symbol_table.symbol_id_by_name(id) {
todo!("Return type for symbol from outer scope");
}
}
Type::Unknown
}
}
ExprContext::Del => Type::None,
ExprContext::Invalid => Type::Unknown,
ExprContext::Store => Type::None,
}
}
fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type {
let ast::ExprAttribute {
value,
attr,
range: _,
ctx,
} = attribute;
let value_ty = self.infer_expression(value);
let member_ty = value_ty
.member(&self.typing_context(), &Name::new(&attr.id))
.unwrap_or(Type::Unknown);
match ctx {
ExprContext::Load => member_ty,
ExprContext::Store | ExprContext::Del => Type::None,
ExprContext::Invalid => Type::Unknown,
}
}
fn infer_binary_expression(&mut self, binary: &ast::ExprBinOp) -> Type {
let ast::ExprBinOp {
left,
op,
right,
range: _,
} = binary;
let left_ty = self.infer_expression(left);
let right_ty = self.infer_expression(right);
match left_ty {
Type::Any => Type::Any,
Type::Unknown => Type::Unknown,
Type::IntLiteral(n) => {
match right_ty {
Type::IntLiteral(m) => {
match op {
ast::Operator::Add => n
.checked_add(m)
.map(Type::IntLiteral)
// TODO builtins.int
.unwrap_or(Type::Unknown),
ast::Operator::Sub => n
.checked_sub(m)
.map(Type::IntLiteral)
// TODO builtins.int
.unwrap_or(Type::Unknown),
ast::Operator::Mult => n
.checked_mul(m)
.map(Type::IntLiteral)
// TODO builtins.int
.unwrap_or(Type::Unknown),
ast::Operator::Div => n
.checked_div(m)
.map(Type::IntLiteral)
// TODO builtins.int
.unwrap_or(Type::Unknown),
ast::Operator::Mod => n
.checked_rem(m)
.map(Type::IntLiteral)
// TODO division by zero error
.unwrap_or(Type::Unknown),
_ => todo!("complete binop op support for IntLiteral"),
}
}
_ => todo!("complete binop right_ty support for IntLiteral"),
}
}
_ => todo!("complete binop support"),
}
}
fn infer_type_parameters(&mut self, _type_parameters: &TypeParams) {
todo!("Infer type parameters")
}
pub(super) fn finish(mut self) -> TypeInference {
let symbol_tys: IndexVec<_, _> = self
.index
.symbol_table(self.file_scope_id)
.symbol_ids()
.map(|symbol| self.local_definition_ty(symbol))
.collect();
self.types.symbol_tys = symbol_tys;
self.types.shrink_to_fit();
self.types
}
fn union_ty(&mut self, ty: UnionType) -> Type {
Type::Union(TypeId {
scope: self.scope,
scoped: self.types.union_types.push(ty),
})
}
fn function_ty(&mut self, ty: FunctionType) -> Type {
Type::Function(TypeId {
scope: self.scope,
scoped: self.types.function_types.push(ty),
})
}
fn class_ty(&mut self, ty: ClassType) -> Type {
Type::Class(TypeId {
scope: self.scope,
scoped: self.types.class_types.push(ty),
})
}
fn typing_context(&self) -> TypingContext {
TypingContext::scoped(self.db, self.scope, &self.types)
}
fn local_definition_ty(&mut self, symbol: ScopedSymbolId) -> Type {
let symbol = self.symbol_table.symbol(symbol);
let mut definitions = symbol
.definitions()
.iter()
.filter_map(|definition| self.definition_tys.get(definition).copied());
let Some(first) = definitions.next() else {
return Type::Unbound;
};
if let Some(second) = definitions.next() {
let context = self.typing_context();
let mut builder = UnionTypeBuilder::new(&context);
builder = builder.add(first).add(second);
for variant in definitions {
builder = builder.add(variant);
}
self.union_ty(builder.build())
} else {
first
}
}
}
#[cfg(test)]
mod tests {
use ruff_db::file_system::FileSystemPathBuf;
use ruff_db::vfs::system_path_to_file;
use crate::db::tests::TestDb;
use crate::module::resolver::{set_module_resolution_settings, ModuleResolutionSettings};
use crate::name::Name;
use crate::red_knot::types::{public_symbol_ty_by_name, Type, TypingContext};
fn setup_db() -> TestDb {
let mut db = TestDb::new();
set_module_resolution_settings(
&mut db,
ModuleResolutionSettings {
extra_paths: Vec::new(),
workspace_root: FileSystemPathBuf::from("/src"),
site_packages: None,
custom_typeshed: None,
},
);
db
}
fn assert_public_ty(db: &TestDb, file_name: &str, symbol_name: &str, expected: &str) {
let file = system_path_to_file(db, file_name).expect("Expected file to exist.");
let ty = public_symbol_ty_by_name(db, file, symbol_name).unwrap_or(Type::Unknown);
assert_eq!(ty.display(&TypingContext::global(db)).to_string(), expected);
}
#[test]
fn follow_import_to_class() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_files([
("src/a.py", "from b import C as D; E = D"),
("src/b.py", "class C: pass"),
])?;
assert_public_ty(&db, "src/a.py", "E", "Literal[C]");
Ok(())
}
#[test]
fn resolve_base_class_by_name() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file(
"src/mod.py",
r#"
class Base:
pass
class Sub(Base):
pass"#,
)?;
let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist.");
let ty = public_symbol_ty_by_name(&db, mod_file, "Sub").expect("Symbol type to exist");
let Type::Class(class_id) = ty else {
panic!("Sub is not a Class")
};
let context = TypingContext::global(&db);
let base_names: Vec<_> = class_id
.lookup(&context)
.bases()
.iter()
.map(|base_ty| format!("{}", base_ty.display(&context)))
.collect();
assert_eq!(base_names, vec!["Literal[Base]"]);
Ok(())
}
#[test]
fn resolve_method() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file(
"src/mod.py",
"
class C:
def f(self): pass
",
)?;
let mod_file = system_path_to_file(&db, "src/mod.py").unwrap();
let ty = public_symbol_ty_by_name(&db, mod_file, "C").unwrap();
let Type::Class(class_id) = ty else {
panic!("C is not a Class");
};
let context = TypingContext::global(&db);
let member_ty = class_id.class_member(&context, &Name::new("f"));
let Some(Type::Function(func_id)) = member_ty else {
panic!("C.f is not a Function");
};
let function_ty = func_id.lookup(&context);
assert_eq!(function_ty.name(), "f");
Ok(())
}
#[test]
fn resolve_module_member() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_files([
("src/a.py", "import b; D = b.C"),
("src/b.py", "class C: pass"),
])?;
assert_public_ty(&db, "src/a.py", "D", "Literal[C]");
Ok(())
}
#[test]
fn resolve_literal() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file("src/a.py", "x = 1")?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1]");
Ok(())
}
#[test]
fn resolve_union() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file(
"src/a.py",
"
if flag:
x = 1
else:
x = 2
",
)?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");
Ok(())
}
#[test]
fn literal_int_arithmetic() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file(
"src/a.py",
"
a = 2 + 1
b = a - 4
c = a * b
d = c / 3
e = 5 % 3
",
)?;
assert_public_ty(&db, "src/a.py", "a", "Literal[3]");
assert_public_ty(&db, "src/a.py", "b", "Literal[-1]");
assert_public_ty(&db, "src/a.py", "c", "Literal[-3]");
assert_public_ty(&db, "src/a.py", "d", "Literal[-1]");
assert_public_ty(&db, "src/a.py", "e", "Literal[2]");
Ok(())
}
#[test]
fn walrus() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system()
.write_file("src/a.py", "x = (y := 1) + 1")?;
assert_public_ty(&db, "src/a.py", "x", "Literal[2]");
assert_public_ty(&db, "src/a.py", "y", "Literal[1]");
Ok(())
}
#[test]
fn ifexpr() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system()
.write_file("src/a.py", "x = 1 if flag else 2")?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");
Ok(())
}
#[test]
fn ifexpr_walrus() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system().write_file(
"src/a.py",
"
y = z = 0
x = (y := 1) if flag else (z := 2)
a = y
b = z
",
)?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]");
assert_public_ty(&db, "src/a.py", "a", "Literal[0, 1]");
assert_public_ty(&db, "src/a.py", "b", "Literal[0, 2]");
Ok(())
}
#[test]
fn ifexpr_nested() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system()
.write_file("src/a.py", "x = 1 if flag else 2 if flag2 else 3")?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2, 3]");
Ok(())
}
#[test]
fn none() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system()
.write_file("src/a.py", "x = 1 if flag else None")?;
assert_public_ty(&db, "src/a.py", "x", "Literal[1] | None");
Ok(())
}
}