[red-knot] intern types using Salsa (#12061)

Intern types using Salsa interning instead of in the `TypeInference`
result.

This eliminates the need for `TypingContext`, and also paves the way for
finer-grained type inference queries.
This commit is contained in:
Carl Meyer 2024-07-05 12:16:37 -07:00 committed by GitHub
parent 7b50061b43
commit 0e44235981
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 190 additions and 534 deletions

11
Cargo.lock generated
View file

@ -1532,6 +1532,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "ordermap"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab5a8e22be64dfa1123429350872e7be33594dbf5ae5212c90c5890e71966d1d"
dependencies = [
"indexmap",
]
[[package]]
name = "os_str_bytes"
version = "6.6.1"
@ -1902,7 +1911,7 @@ dependencies = [
"anyhow",
"bitflags 2.6.0",
"hashbrown 0.14.5",
"indexmap",
"ordermap",
"red_knot_module_resolver",
"ruff_db",
"ruff_index",

View file

@ -72,7 +72,6 @@ hashbrown = "0.14.3"
ignore = { version = "0.4.22" }
imara-diff = { version = "0.1.5" }
imperative = { version = "1.0.4" }
indexmap = { version = "2.2.6" }
indicatif = { version = "0.17.8" }
indoc = { version = "2.0.4" }
insta = { version = "1.35.1" }
@ -95,6 +94,7 @@ mimalloc = { version = "0.1.39" }
natord = { version = "1.0.9" }
notify = { version = "6.1.1" }
once_cell = { version = "1.19.0" }
ordermap = { version = "0.5.0" }
path-absolutize = { version = "3.1.1" }
path-slash = { version = "0.2.1" }
pathdiff = { version = "0.2.1" }

View file

@ -122,7 +122,6 @@ fn lint_unresolved_imports(context: &SemanticLintContext, import: AnyImportRef)
fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) {
let semantic = &context.semantic;
let typing_context = semantic.typing_context();
// TODO we should have a special marker on the real typing module (from typeshed) so if you
// have your own "typing" module in your project, we don't consider it THE typing module (and
@ -150,17 +149,17 @@ fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) {
return;
};
if ty.has_decorator(&typing_context, override_ty) {
let method_name = ty.name(&typing_context);
if class_ty
.inherited_class_member(&typing_context, method_name)
.is_none()
{
// TODO this shouldn't make direct use of the Db; see comment on SemanticModel::db
let db = semantic.db();
if ty.has_decorator(db, override_ty) {
let method_name = ty.name(db);
if class_ty.inherited_class_member(db, &method_name).is_none() {
// TODO should have a qualname() method to support nested classes
context.push_diagnostic(
format!(
"Method {}.{} is decorated with `typing.override` but does not override any base class method",
class_ty.name(&typing_context),
class_ty.name(db),
method_name,
));
}

View file

@ -18,7 +18,7 @@ ruff_python_ast = { workspace = true }
ruff_text_size = { workspace = true }
bitflags = { workspace = true }
indexmap = { workspace = true }
ordermap = { workspace = true }
salsa = { workspace = true }
tracing = { workspace = true }
rustc-hash = { workspace = true }

View file

@ -7,13 +7,19 @@ use red_knot_module_resolver::Db as ResolverDb;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::{public_symbols_map, PublicSymbolId, ScopeId};
use crate::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::types::{infer_types, public_symbol_ty};
use crate::types::{
infer_types, public_symbol_ty, ClassType, FunctionType, IntersectionType, UnionType,
};
#[salsa::jar(db=Db)]
pub struct Jar(
ScopeId<'_>,
PublicSymbolId<'_>,
Definition<'_>,
FunctionType<'_>,
ClassType<'_>,
UnionType<'_>,
IntersectionType<'_>,
symbol_table,
root_scope,
semantic_index,

View file

@ -12,4 +12,4 @@ pub mod semantic_index;
mod semantic_model;
pub mod types;
type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;
type FxOrderSet<V> = ordermap::set::OrderSet<V, BuildHasherDefault<FxHasher>>;

View file

@ -1,10 +0,0 @@
use std::hash::BuildHasherDefault;
use rustc_hash::FxHasher;
pub mod ast_node_ref;
mod node_key;
pub mod semantic_index;
pub mod types;
pub(crate) type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;

View file

@ -155,7 +155,6 @@ impl<'db> PublicSymbolsMap<'db> {
/// A cross-module identifier of a scope that can be used as a salsa query parameter.
#[salsa::tracked]
pub struct ScopeId<'db> {
#[allow(clippy::used_underscore_binding)]
#[id]
pub file: VfsFile,
#[id]

View file

@ -6,7 +6,7 @@ use ruff_python_ast::{Expr, ExpressionRef, StmtClassDef};
use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::symbol::PublicSymbolId;
use crate::semantic_index::{public_symbol, semantic_index};
use crate::types::{infer_types, public_symbol_ty, Type, TypingContext};
use crate::types::{infer_types, public_symbol_ty, Type};
use crate::Db;
pub struct SemanticModel<'db> {
@ -19,6 +19,12 @@ impl<'db> SemanticModel<'db> {
Self { db, file }
}
// TODO we don't actually want to expose the Db directly to lint rules, but we need to find a
// solution for exposing information from types
pub fn db(&self) -> &dyn Db {
self.db
}
pub fn resolve_module(&self, module_name: ModuleName) -> Option<Module> {
resolve_module(self.db.upcast(), module_name)
}
@ -27,13 +33,9 @@ impl<'db> SemanticModel<'db> {
public_symbol(self.db, module.file(), symbol_name)
}
pub fn public_symbol_ty(&self, symbol: PublicSymbolId<'db>) -> Type<'db> {
pub fn public_symbol_ty(&self, symbol: PublicSymbolId<'db>) -> Type {
public_symbol_ty(self.db, symbol)
}
pub fn typing_context(&self) -> TypingContext<'db, '_> {
TypingContext::global(self.db)
}
}
pub trait HasTy {

View file

@ -1,13 +1,11 @@
use ruff_db::parsed::parsed_module;
use ruff_db::vfs::VfsFile;
use ruff_index::newtype_index;
use ruff_python_ast::name::Name;
use crate::semantic_index::symbol::{FileScopeId, NodeWithScopeKind, PublicSymbolId, ScopeId};
use crate::semantic_index::symbol::{NodeWithScopeKind, PublicSymbolId, ScopeId};
use crate::semantic_index::{public_symbol, root_scope, semantic_index, symbol_table};
use crate::types::infer::{TypeInference, TypeInferenceBuilder};
use crate::Db;
use crate::FxIndexSet;
use crate::{Db, FxOrderSet};
mod display;
mod infer;
@ -43,12 +41,12 @@ pub(crate) fn public_symbol_ty<'db>(db: &'db dyn Db, symbol: PublicSymbolId<'db>
let file = symbol.file(db);
let scope = root_scope(db, file);
// TODO switch to inferring just the definition(s), not the whole scope
let inference = infer_types(db, scope);
inference.symbol_ty(symbol.scoped_symbol_id(db))
}
/// Shorthand for [`public_symbol_ty()`] that takes a symbol name instead of a [`PublicSymbolId`].
#[allow(unused)]
/// Shorthand for `public_symbol_ty` that takes a symbol name instead of a [`PublicSymbolId`].
pub(crate) fn public_symbol_ty_by_name<'db>(
db: &'db dyn Db,
file: VfsFile,
@ -91,7 +89,7 @@ pub(crate) fn infer_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> TypeInfe
}
/// unique ID for a type
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub enum Type<'db> {
/// the dynamic type: a statically-unknown set of values
Any,
@ -105,15 +103,15 @@ pub enum Type<'db> {
/// the None object (TODO remove this in favor of Instance(types.NoneType)
None,
/// a specific function object
Function(TypeId<'db, ScopedFunctionTypeId>),
Function(FunctionType<'db>),
/// a specific module object
Module(TypeId<'db, ScopedModuleTypeId>),
Module(VfsFile),
/// a specific class object
Class(TypeId<'db, ScopedClassTypeId>),
Class(ClassType<'db>),
/// the set of Python objects with the given class in their __class__'s method resolution order
Instance(TypeId<'db, ScopedClassTypeId>),
Union(TypeId<'db, ScopedUnionTypeId>),
Intersection(TypeId<'db, ScopedIntersectionTypeId>),
Instance(ClassType<'db>),
Union(UnionType<'db>),
Intersection(IntersectionType<'db>),
IntLiteral(i64),
// TODO protocols, callable types, overloads, generics, type vars
}
@ -127,7 +125,7 @@ impl<'db> Type<'db> {
matches!(self, Type::Unknown)
}
pub fn member(&self, context: &TypingContext<'db, '_>, name: &Name) -> Option<Type<'db>> {
pub fn member(&self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
match self {
Type::Any => Some(Type::Any),
Type::Never => todo!("attribute lookup on Never type"),
@ -135,14 +133,13 @@ impl<'db> Type<'db> {
Type::Unbound => todo!("attribute lookup on Unbound type"),
Type::None => todo!("attribute lookup on None type"),
Type::Function(_) => todo!("attribute lookup on Function type"),
Type::Module(module) => module.member(context, name),
Type::Class(class) => class.class_member(context, name),
Type::Module(file) => public_symbol_ty_by_name(db, *file, name),
Type::Class(class) => class.class_member(db, name),
Type::Instance(_) => {
// TODO MRO? get_own_instance_member, get_instance_member
todo!("attribute lookup on Instance type")
}
Type::Union(union_id) => {
let _union = union_id.lookup(context);
Type::Union(_) => {
// TODO perform the get_member on each type in the union
// TODO return the union of those results
// TODO if any of those results is `None` then include Unknown in the result union
@ -161,155 +158,25 @@ impl<'db> Type<'db> {
}
}
/// ID that uniquely identifies a type in a program.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TypeId<'db, L> {
/// The scope in which this type is defined or was created.
scope: ScopeId<'db>,
/// The type's local ID in its scope.
scoped: L,
}
impl<'db, Id> TypeId<'db, Id>
where
Id: Copy,
{
pub fn scope(&self) -> ScopeId<'db> {
self.scope
}
pub fn scoped_id(&self) -> Id {
self.scoped
}
/// Resolves the type ID to the actual type.
pub(crate) fn lookup<'a>(self, context: &'a TypingContext<'db, 'a>) -> &'a Id::Ty<'db>
where
Id: ScopedTypeId,
{
let types = context.types(self.scope);
self.scoped.lookup_scoped(types)
}
}
/// ID that uniquely identifies a type in a scope.
pub(crate) trait ScopedTypeId {
/// The type that this ID points to.
type Ty<'db>;
/// Looks up the type in `index`.
///
/// ## Panics
/// May panic if this type is from another scope than `index`, or might just return an invalid type.
fn lookup_scoped<'a, 'db>(self, index: &'a TypeInference<'db>) -> &'a Self::Ty<'db>;
}
/// ID uniquely identifying a function type in a `scope`.
#[newtype_index]
pub struct ScopedFunctionTypeId;
impl ScopedTypeId for ScopedFunctionTypeId {
type Ty<'db> = FunctionType<'db>;
fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> {
types.function_ty(self)
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct FunctionType<'a> {
#[salsa::interned]
pub struct FunctionType<'db> {
/// name of the function at definition
name: Name,
pub name: Name,
/// types of all decorators on this function
decorators: Vec<Type<'a>>,
decorators: Vec<Type<'db>>,
}
impl<'a> FunctionType<'a> {
fn name(&self) -> &str {
self.name.as_str()
}
#[allow(unused)]
pub(crate) fn decorators(&self) -> &[Type<'a>] {
self.decorators.as_slice()
impl<'db> FunctionType<'db> {
pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool {
self.decorators(db).contains(&decorator)
}
}
impl<'db> TypeId<'db, ScopedFunctionTypeId> {
pub fn name<'a>(self, context: &'a TypingContext<'db, 'a>) -> &'a Name {
let function_ty = self.lookup(context);
&function_ty.name
}
pub fn has_decorator(self, context: &TypingContext, decorator: Type<'db>) -> bool {
let function_ty = self.lookup(context);
function_ty.decorators.contains(&decorator)
}
}
#[newtype_index]
pub struct ScopedClassTypeId;
impl ScopedTypeId for ScopedClassTypeId {
type Ty<'db> = ClassType<'db>;
fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> {
types.class_ty(self)
}
}
impl<'db> TypeId<'db, ScopedClassTypeId> {
pub fn name<'a>(self, context: &'a TypingContext<'db, 'a>) -> &'a Name {
let class_ty = self.lookup(context);
&class_ty.name
}
/// Returns the class member of this class named `name`.
///
/// The member resolves to a member of the class itself or any of its bases.
pub fn class_member(self, context: &TypingContext<'db, '_>, name: &Name) -> Option<Type<'db>> {
if let Some(member) = self.own_class_member(context, name) {
return Some(member);
}
self.inherited_class_member(context, name)
}
/// Returns the inferred type of the class member named `name`.
pub fn own_class_member(
self,
context: &TypingContext<'db, '_>,
name: &Name,
) -> Option<Type<'db>> {
let class = self.lookup(context);
let symbols = symbol_table(context.db, class.body_scope);
let symbol = symbols.symbol_id_by_name(name)?;
let types = context.types(class.body_scope);
Some(types.symbol_ty(symbol))
}
pub fn inherited_class_member(
self,
context: &TypingContext<'db, '_>,
name: &Name,
) -> Option<Type<'db>> {
let class = self.lookup(context);
for base in &class.bases {
if let Some(member) = base.member(context, name) {
return Some(member);
}
}
None
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
#[salsa::interned]
pub struct ClassType<'db> {
/// Name of the class at definition
name: Name,
pub name: Name,
/// Types of all class bases
bases: Vec<Type<'db>>,
@ -318,52 +185,62 @@ pub struct ClassType<'db> {
}
impl<'db> ClassType<'db> {
fn name(&self) -> &str {
self.name.as_str()
/// Returns the class member of this class named `name`.
///
/// The member resolves to a member of the class itself or any of its bases.
pub fn class_member(self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
if let Some(member) = self.own_class_member(db, name) {
return Some(member);
}
self.inherited_class_member(db, name)
}
#[allow(unused)]
pub(super) fn bases(&self) -> &'db [Type] {
self.bases.as_slice()
/// Returns the inferred type of the class member named `name`.
pub fn own_class_member(self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
let scope = self.body_scope(db);
let symbols = symbol_table(db, scope);
let symbol = symbols.symbol_id_by_name(name)?;
let types = infer_types(db, scope);
Some(types.symbol_ty(symbol))
}
pub fn inherited_class_member(self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
for base in self.bases(db) {
if let Some(member) = base.member(db, name) {
return Some(member);
}
}
None
}
}
#[newtype_index]
pub struct ScopedUnionTypeId;
impl ScopedTypeId for ScopedUnionTypeId {
type Ty<'db> = UnionType<'db>;
fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> {
types.union_ty(self)
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
#[salsa::interned]
pub struct UnionType<'db> {
// the union type includes values in any of these types
elements: FxIndexSet<Type<'db>>,
/// the union type includes values in any of these types
elements: FxOrderSet<Type<'db>>,
}
struct UnionTypeBuilder<'db, 'a> {
elements: FxIndexSet<Type<'db>>,
context: &'a TypingContext<'db, 'a>,
struct UnionTypeBuilder<'db> {
elements: FxOrderSet<Type<'db>>,
db: &'db dyn Db,
}
impl<'db, 'a> UnionTypeBuilder<'db, 'a> {
fn new(context: &'a TypingContext<'db, 'a>) -> Self {
impl<'db> UnionTypeBuilder<'db> {
fn new(db: &'db dyn Db) -> Self {
Self {
context,
elements: FxIndexSet::default(),
db,
elements: FxOrderSet::default(),
}
}
/// Adds a type to this union.
fn add(mut self, ty: Type<'db>) -> Self {
match ty {
Type::Union(union_id) => {
let union = union_id.lookup(self.context);
self.elements.extend(&union.elements);
Type::Union(union) => {
self.elements.extend(&union.elements(self.db));
}
_ => {
self.elements.insert(ty);
@ -374,20 +251,7 @@ impl<'db, 'a> UnionTypeBuilder<'db, 'a> {
}
fn build(self) -> UnionType<'db> {
UnionType {
elements: self.elements,
}
}
}
#[newtype_index]
pub struct ScopedIntersectionTypeId;
impl ScopedTypeId for ScopedIntersectionTypeId {
type Ty<'db> = IntersectionType<'db>;
fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> {
types.intersection_ty(self)
UnionType::new(self.db, self.elements)
}
}
@ -397,104 +261,12 @@ impl ScopedTypeId for ScopedIntersectionTypeId {
// case where a Not appears outside an intersection (unclear when that could even happen, but we'd
// have to represent it as a single-element intersection if it did) in exchange for better
// efficiency in the within-intersection case.
#[derive(Debug, PartialEq, Eq, Clone)]
#[salsa::interned]
pub struct IntersectionType<'db> {
// the intersection type includes only values in all of these types
positive: FxIndexSet<Type<'db>>,
positive: FxOrderSet<Type<'db>>,
// the intersection type does not include any value in any of these types
negative: FxIndexSet<Type<'db>>,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ScopedModuleTypeId;
impl ScopedTypeId for ScopedModuleTypeId {
type Ty<'db> = ModuleType;
fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> {
types.module_ty()
}
}
impl<'db> TypeId<'db, ScopedModuleTypeId> {
fn member(self, context: &TypingContext<'db, '_>, name: &Name) -> Option<Type<'db>> {
context.public_symbol_ty(self.scope.file(context.db), name)
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct ModuleType {
file: VfsFile,
}
/// Context in which to resolve types.
///
/// This abstraction is necessary to support a uniform API that can be used
/// while in the process of building the type inference structure for a scope
/// but also when all types should be resolved by querying the db.
pub struct TypingContext<'db, 'inference> {
db: &'db dyn Db,
/// The Local type inference scope that is in the process of being built.
///
/// Bypass the `db` when resolving the types for this scope.
local: Option<(ScopeId<'db>, &'inference TypeInference<'db>)>,
}
impl<'db, 'inference> TypingContext<'db, 'inference> {
/// Creates a context that resolves all types by querying the db.
#[allow(unused)]
pub(super) fn global(db: &'db dyn Db) -> Self {
Self { db, local: None }
}
/// Creates a context that by-passes the `db` when resolving types from `scope_id` and instead uses `types`.
fn scoped(
db: &'db dyn Db,
scope_id: ScopeId<'db>,
types: &'inference TypeInference<'db>,
) -> Self {
Self {
db,
local: Some((scope_id, types)),
}
}
/// Returns the [`TypeInference`] results (not guaranteed to be complete) for `scope_id`.
fn types(&self, scope_id: ScopeId<'db>) -> &'inference TypeInference<'db> {
if let Some((scope, local_types)) = self.local {
if scope == scope_id {
return local_types;
}
}
infer_types(self.db, scope_id)
}
fn module_ty(&self, file: VfsFile) -> Type<'db> {
let scope = root_scope(self.db, file);
Type::Module(TypeId {
scope,
scoped: ScopedModuleTypeId,
})
}
/// Resolves the public type of a symbol named `name` defined in `file`.
///
/// This function calls [`public_symbol_ty`] if the local scope isn't the module scope of `file`.
/// It otherwise tries to resolve the symbol type locally.
fn public_symbol_ty(&self, file: VfsFile, name: &Name) -> Option<Type<'db>> {
let symbol = public_symbol(self.db, file, name)?;
if let Some((scope, local_types)) = self.local {
if scope.file_scope_id(self.db) == FileScopeId::root() && scope.file(self.db) == file {
return Some(local_types.symbol_ty(symbol.scoped_symbol_id(self.db)));
}
}
Some(public_symbol_ty(self.db, symbol))
}
negative: FxOrderSet<Type<'db>>,
}
#[cfg(test)]
@ -508,7 +280,7 @@ mod tests {
assert_will_not_run_function_query, assert_will_run_function_query, TestDb,
};
use crate::semantic_index::root_scope;
use crate::types::{infer_types, public_symbol_ty_by_name, TypingContext};
use crate::types::{infer_types, public_symbol_ty_by_name};
use crate::{HasTy, SemanticModel};
fn setup_db() -> TestDb {
@ -540,10 +312,7 @@ mod tests {
let literal_ty = statement.value.ty(&model);
assert_eq!(
format!("{}", literal_ty.display(&TypingContext::global(&db))),
"Literal[10]"
);
assert_eq!(format!("{}", literal_ty.display(&db)), "Literal[10]");
Ok(())
}
@ -560,10 +329,7 @@ mod tests {
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
// Change `x` to a different value
db.memory_file_system()
@ -577,10 +343,7 @@ mod tests {
db.clear_salsa_events();
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[20]"
);
assert_eq!(x_ty_2.display(&db).to_string(), "Literal[20]");
let events = db.take_salsa_events();
@ -607,10 +370,7 @@ mod tests {
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
db.memory_file_system()
.write_file("/src/foo.py", "x = 10\ndef foo(): pass")?;
@ -624,10 +384,7 @@ mod tests {
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]");
let events = db.take_salsa_events();
@ -655,10 +412,7 @@ mod tests {
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
db.memory_file_system()
.write_file("/src/foo.py", "x = 10\ny = 30")?;
@ -672,10 +426,7 @@ mod tests {
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]");
let events = db.take_salsa_events();

View file

@ -2,18 +2,19 @@
use std::fmt::{Display, Formatter};
use crate::types::{IntersectionType, Type, TypingContext, UnionType};
use crate::types::{IntersectionType, Type, UnionType};
use crate::Db;
impl Type<'_> {
pub fn display<'a>(&'a self, context: &'a TypingContext) -> DisplayType<'a> {
DisplayType { ty: self, context }
impl<'db> Type<'db> {
pub fn display(&'db self, db: &'db dyn Db) -> DisplayType<'db> {
DisplayType { ty: self, db }
}
}
#[derive(Copy, Clone)]
pub struct DisplayType<'a> {
ty: &'a Type<'a>,
context: &'a TypingContext<'a, 'a>,
pub struct DisplayType<'db> {
ty: &'db Type<'db>,
db: &'db dyn Db,
}
impl Display for DisplayType<'_> {
@ -24,42 +25,19 @@ impl Display for DisplayType<'_> {
Type::Unknown => f.write_str("Unknown"),
Type::Unbound => f.write_str("Unbound"),
Type::None => f.write_str("None"),
Type::Module(module_id) => {
write!(
f,
"<module '{:?}'>",
module_id
.scope
.file(self.context.db)
.path(self.context.db.upcast())
)
Type::Module(file) => {
write!(f, "<module '{:?}'>", file.path(self.db.upcast()))
}
// TODO functions and classes should display using a fully qualified name
Type::Class(class_id) => {
let class = class_id.lookup(self.context);
Type::Class(class) => {
f.write_str("Literal[")?;
f.write_str(class.name())?;
f.write_str(&class.name(self.db))?;
f.write_str("]")
}
Type::Instance(class_id) => {
let class = class_id.lookup(self.context);
f.write_str(class.name())
}
Type::Function(function_id) => {
let function = function_id.lookup(self.context);
f.write_str(function.name())
}
Type::Union(union_id) => {
let union = union_id.lookup(self.context);
union.display(self.context).fmt(f)
}
Type::Intersection(intersection_id) => {
let intersection = intersection_id.lookup(self.context);
intersection.display(self.context).fmt(f)
}
Type::Instance(class) => f.write_str(&class.name(self.db)),
Type::Function(function) => f.write_str(&function.name(self.db)),
Type::Union(union) => union.display(self.db).fmt(f),
Type::Intersection(intersection) => intersection.display(self.db).fmt(f),
Type::IntLiteral(n) => write!(f, "Literal[{n}]"),
}
}
@ -71,15 +49,15 @@ impl std::fmt::Debug for DisplayType<'_> {
}
}
impl UnionType<'_> {
fn display<'a>(&'a self, context: &'a TypingContext<'a, 'a>) -> DisplayUnionType<'a> {
DisplayUnionType { context, ty: self }
impl<'db> UnionType<'db> {
fn display(&'db self, db: &'db dyn Db) -> DisplayUnionType<'db> {
DisplayUnionType { db, ty: self }
}
}
struct DisplayUnionType<'a> {
ty: &'a UnionType<'a>,
context: &'a TypingContext<'a, 'a>,
struct DisplayUnionType<'db> {
ty: &'db UnionType<'db>,
db: &'db dyn Db,
}
impl Display for DisplayUnionType<'_> {
@ -87,7 +65,7 @@ impl Display for DisplayUnionType<'_> {
let union = self.ty;
let (int_literals, other_types): (Vec<Type>, Vec<Type>) = union
.elements
.elements(self.db)
.iter()
.copied()
.partition(|ty| matches!(ty, Type::IntLiteral(_)));
@ -121,7 +99,7 @@ impl Display for DisplayUnionType<'_> {
f.write_str(" | ")?;
};
first = false;
write!(f, "{}", ty.display(self.context))?;
write!(f, "{}", ty.display(self.db))?;
}
Ok(())
@ -134,15 +112,15 @@ impl std::fmt::Debug for DisplayUnionType<'_> {
}
}
impl IntersectionType<'_> {
fn display<'a>(&'a self, context: &'a TypingContext<'a, 'a>) -> DisplayIntersectionType<'a> {
DisplayIntersectionType { ty: self, context }
impl<'db> IntersectionType<'db> {
fn display(&'db self, db: &'db dyn Db) -> DisplayIntersectionType<'db> {
DisplayIntersectionType { db, ty: self }
}
}
struct DisplayIntersectionType<'a> {
ty: &'a IntersectionType<'a>,
context: &'a TypingContext<'a, 'a>,
struct DisplayIntersectionType<'db> {
ty: &'db IntersectionType<'db>,
db: &'db dyn Db,
}
impl Display for DisplayIntersectionType<'_> {
@ -150,10 +128,10 @@ impl Display for DisplayIntersectionType<'_> {
let mut first = true;
for (neg, ty) in self
.ty
.positive
.positive(self.db)
.iter()
.map(|ty| (false, ty))
.chain(self.ty.negative.iter().map(|ty| (true, ty)))
.chain(self.ty.negative(self.db).iter().map(|ty| (true, ty)))
{
if !first {
f.write_str(" & ")?;
@ -162,7 +140,7 @@ impl Display for DisplayIntersectionType<'_> {
if neg {
f.write_str("~")?;
};
write!(f, "{}", ty.display(self.context))?;
write!(f, "{}", ty.display(self.db))?;
}
Ok(())
}

View file

@ -2,8 +2,7 @@ use rustc_hash::FxHashMap;
use std::borrow::Cow;
use std::sync::Arc;
use red_knot_module_resolver::resolve_module;
use red_knot_module_resolver::ModuleName;
use red_knot_module_resolver::{resolve_module, ModuleName};
use ruff_db::vfs::VfsFile;
use ruff_index::IndexVec;
use ruff_python_ast as ast;
@ -15,81 +14,40 @@ use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeRef, ScopeId, ScopedSymbolId, SymbolTable,
};
use crate::semantic_index::{symbol_table, SemanticIndex};
use crate::types::{
infer_types, ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId,
ScopedFunctionTypeId, ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext,
UnionType, UnionTypeBuilder,
};
use crate::types::{infer_types, ClassType, FunctionType, Name, Type, UnionTypeBuilder};
use crate::Db;
/// The inferred types for a single scope.
#[derive(Debug, Eq, PartialEq, Default, Clone)]
pub(crate) struct TypeInference<'db> {
/// The type of the module if the scope is a module scope.
module_type: Option<ModuleType>,
/// The types of the defined classes in this scope.
class_types: IndexVec<ScopedClassTypeId, ClassType<'db>>,
/// The types of the defined functions in this scope.
function_types: IndexVec<ScopedFunctionTypeId, FunctionType<'db>>,
union_types: IndexVec<ScopedUnionTypeId, UnionType<'db>>,
intersection_types: IndexVec<ScopedIntersectionTypeId, IntersectionType<'db>>,
/// The types of every expression in this scope.
expression_tys: IndexVec<ScopedExpressionId, Type<'db>>,
expressions: IndexVec<ScopedExpressionId, Type<'db>>,
/// The public types of every symbol in this scope.
symbol_tys: IndexVec<ScopedSymbolId, Type<'db>>,
symbols: IndexVec<ScopedSymbolId, Type<'db>>,
/// The type of a definition.
definition_tys: FxHashMap<Definition<'db>, Type<'db>>,
definitions: FxHashMap<Definition<'db>, Type<'db>>,
}
impl<'db> TypeInference<'db> {
#[allow(unused)]
pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> {
self.expression_tys[expression]
self.expressions[expression]
}
pub(super) fn symbol_ty(&self, symbol: ScopedSymbolId) -> Type<'db> {
self.symbol_tys[symbol]
self.symbols[symbol]
}
pub(super) fn module_ty(&self) -> &ModuleType {
self.module_type.as_ref().unwrap()
}
pub(super) fn class_ty(&self, id: ScopedClassTypeId) -> &ClassType<'db> {
&self.class_types[id]
}
pub(super) fn function_ty(&self, id: ScopedFunctionTypeId) -> &FunctionType<'db> {
&self.function_types[id]
}
pub(super) fn union_ty(&self, id: ScopedUnionTypeId) -> &UnionType<'db> {
&self.union_types[id]
}
pub(super) fn intersection_ty(&self, id: ScopedIntersectionTypeId) -> &IntersectionType<'db> {
&self.intersection_types[id]
}
pub(crate) fn definition_ty(&self, definition: Definition) -> Type<'db> {
self.definition_tys[&definition]
pub(crate) fn definition_ty(&self, definition: Definition<'db>) -> Type<'db> {
self.definitions[&definition]
}
fn shrink_to_fit(&mut self) {
self.class_types.shrink_to_fit();
self.function_types.shrink_to_fit();
self.union_types.shrink_to_fit();
self.intersection_types.shrink_to_fit();
self.expression_tys.shrink_to_fit();
self.symbol_tys.shrink_to_fit();
self.definition_tys.shrink_to_fit();
self.expressions.shrink_to_fit();
self.symbols.shrink_to_fit();
self.definitions.shrink_to_fit();
}
}
@ -99,7 +57,6 @@ pub(super) struct TypeInferenceBuilder<'db> {
// Cached lookups
index: &'db SemanticIndex<'db>,
scope: ScopeId<'db>,
file_scope_id: FileScopeId,
file_id: VfsFile,
symbol_table: Arc<SymbolTable<'db>>,
@ -123,7 +80,6 @@ impl<'db> TypeInferenceBuilder<'db> {
index,
file_scope_id,
file_id: file,
scope,
symbol_table,
db,
@ -205,13 +161,11 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_expression(return_ty);
}
let function_ty = self.function_ty(FunctionType {
name: name.id.clone(),
decorators: decorator_tys,
});
let function_ty =
Type::Function(FunctionType::new(self.db, name.id.clone(), decorator_tys));
let definition = self.index.definition(function);
self.types.definition_tys.insert(definition, function_ty);
self.types.definitions.insert(definition, function_ty);
}
fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) {
@ -233,16 +187,15 @@ impl<'db> TypeInferenceBuilder<'db> {
.map(|arguments| self.infer_arguments(arguments))
.unwrap_or(Vec::new());
let body_scope = self.index.node_scope(NodeWithScopeRef::Class(class));
let body_scope = self
.index
.node_scope(NodeWithScopeRef::Class(class))
.to_scope_id(self.db, self.file_id);
let class_ty = self.class_ty(ClassType {
name: name.id.clone(),
bases,
body_scope: body_scope.to_scope_id(self.db, self.file_id),
});
let class_ty = Type::Class(ClassType::new(self.db, name.id.clone(), bases, body_scope));
let definition = self.index.definition(class);
self.types.definition_tys.insert(definition, class_ty);
self.types.definitions.insert(definition, class_ty);
}
fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) {
@ -283,7 +236,7 @@ impl<'db> TypeInferenceBuilder<'db> {
for target in targets {
self.infer_expression(target);
self.types.definition_tys.insert(
self.types.definitions.insert(
self.index.definition(DefinitionNodeRef::Target(target)),
value_ty,
);
@ -306,7 +259,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let annotation_ty = self.infer_expression(annotation);
self.infer_expression(target);
self.types.definition_tys.insert(
self.types.definitions.insert(
self.index.definition(DefinitionNodeRef::Target(target)),
annotation_ty,
);
@ -341,12 +294,12 @@ impl<'db> TypeInferenceBuilder<'db> {
let module_name = ModuleName::new(&name.id);
let module = module_name.and_then(|name| resolve_module(self.db.upcast(), name));
let module_ty = module
.map(|module| self.typing_context().module_ty(module.file()))
.map(|module| Type::Module(module.file()))
.unwrap_or(Type::Unknown);
let definition = self.index.definition(alias);
self.types.definition_tys.insert(definition, module_ty);
self.types.definitions.insert(definition, module_ty);
}
}
@ -363,7 +316,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let module =
module_name.and_then(|module_name| resolve_module(self.db.upcast(), module_name));
let module_ty = module
.map(|module| self.typing_context().module_ty(module.file()))
.map(|module| Type::Module(module.file()))
.unwrap_or(Type::Unknown);
for alias in names {
@ -374,11 +327,11 @@ impl<'db> TypeInferenceBuilder<'db> {
} = alias;
let ty = module_ty
.member(&self.typing_context(), &name.id)
.member(self.db, &Name::new(&name.id))
.unwrap_or(Type::Unknown);
let definition = self.index.definition(alias);
self.types.definition_tys.insert(definition, ty);
self.types.definitions.insert(definition, ty);
}
}
@ -425,7 +378,7 @@ impl<'db> TypeInferenceBuilder<'db> {
_ => todo!("expression type resolution for {:?}", expression),
};
self.types.expression_tys.push(ty);
self.types.expressions.push(ty);
ty
}
@ -455,7 +408,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_expression(target);
self.types
.definition_tys
.definitions
.insert(self.index.definition(named), value_ty);
value_ty
@ -475,12 +428,12 @@ impl<'db> TypeInferenceBuilder<'db> {
let body_ty = self.infer_expression(body);
let orelse_ty = self.infer_expression(orelse);
let union = UnionTypeBuilder::new(&self.typing_context())
let union = UnionTypeBuilder::new(self.db)
.add(body_ty)
.add(orelse_ty)
.build();
self.union_ty(union)
Type::Union(union)
}
fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> {
@ -537,7 +490,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let value_ty = self.infer_expression(value);
let member_ty = value_ty
.member(&self.typing_context(), &attr.id)
.member(self.db, &Name::new(&attr.id))
.unwrap_or(Type::Unknown);
match ctx {
@ -612,57 +565,31 @@ impl<'db> TypeInferenceBuilder<'db> {
.map(|symbol| self.local_definition_ty(symbol))
.collect();
self.types.symbol_tys = symbol_tys;
self.types.symbols = symbol_tys;
self.types.shrink_to_fit();
self.types
}
fn union_ty(&mut self, ty: UnionType<'db>) -> Type<'db> {
Type::Union(TypeId {
scope: self.scope,
scoped: self.types.union_types.push(ty),
})
}
fn function_ty(&mut self, ty: FunctionType<'db>) -> Type<'db> {
Type::Function(TypeId {
scope: self.scope,
scoped: self.types.function_types.push(ty),
})
}
fn class_ty(&mut self, ty: ClassType<'db>) -> Type<'db> {
Type::Class(TypeId {
scope: self.scope,
scoped: self.types.class_types.push(ty),
})
}
fn typing_context(&self) -> TypingContext<'db, '_> {
TypingContext::scoped(self.db, self.scope, &self.types)
}
fn local_definition_ty(&mut self, symbol: ScopedSymbolId) -> Type<'db> {
let symbol = self.symbol_table.symbol(symbol);
let mut definitions = symbol
.definitions()
.iter()
.filter_map(|definition| self.types.definition_tys.get(definition).copied());
.filter_map(|definition| self.types.definitions.get(definition).copied());
let Some(first) = definitions.next() else {
return Type::Unbound;
};
if let Some(second) = definitions.next() {
let context = self.typing_context();
let mut builder = UnionTypeBuilder::new(&context);
let mut builder = UnionTypeBuilder::new(self.db);
builder = builder.add(first).add(second);
for variant in definitions {
builder = builder.add(variant);
}
self.union_ty(builder.build())
Type::Union(builder.build())
} else {
first
}
@ -677,7 +604,7 @@ mod tests {
use ruff_python_ast::name::Name;
use crate::db::tests::TestDb;
use crate::types::{public_symbol_ty_by_name, Type, TypingContext};
use crate::types::{public_symbol_ty_by_name, Type};
fn setup_db() -> TestDb {
let mut db = TestDb::new();
@ -699,7 +626,7 @@ mod tests {
let file = system_path_to_file(db, file_name).expect("Expected file to exist.");
let ty = public_symbol_ty_by_name(db, file, symbol_name).unwrap_or(Type::Unknown);
assert_eq!(ty.display(&TypingContext::global(db)).to_string(), expected);
assert_eq!(ty.display(db).to_string(), expected);
}
#[test]
@ -733,17 +660,14 @@ class Sub(Base):
let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist.");
let ty = public_symbol_ty_by_name(&db, mod_file, "Sub").expect("Symbol type to exist");
let Type::Class(class_id) = ty else {
let Type::Class(class) = ty else {
panic!("Sub is not a Class")
};
let context = TypingContext::global(&db);
let base_names: Vec<_> = class_id
.lookup(&context)
.bases()
let base_names: Vec<_> = class
.bases(&db)
.iter()
.map(|base_ty| format!("{}", base_ty.display(&context)))
.map(|base_ty| format!("{}", base_ty.display(&db)))
.collect();
assert_eq!(base_names, vec!["Literal[Base]"]);
@ -770,15 +694,13 @@ class C:
panic!("C is not a Class");
};
let context = TypingContext::global(&db);
let member_ty = class_id.class_member(&context, &Name::new_static("f"));
let member_ty = class_id.class_member(&db, &Name::new_static("f"));
let Some(Type::Function(func_id)) = member_ty else {
let Some(Type::Function(func)) = member_ty else {
panic!("C.f is not a Function");
};
let function_ty = func_id.lookup(&context);
assert_eq!(function_ty.name(), "f");
assert_eq!(func.name(&db), "f");
Ok(())
}