[red-knot] more ergonomic and efficient handling of known builtin classes (#13615)

This commit is contained in:
Simon 2024-10-05 19:03:46 +02:00 committed by GitHub
parent 7c5a7d909c
commit 1c2cafc101
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 208 additions and 74 deletions

View file

@ -14,7 +14,7 @@ use crate::stdlib::{
builtins_symbol_ty, types_symbol_ty, typeshed_symbol_ty, typing_extensions_symbol_ty,
};
use crate::types::narrow::narrowing_constraint;
use crate::{Db, FxOrderSet};
use crate::{Db, FxOrderSet, Module};
pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder};
pub(crate) use self::diagnostic::TypeCheckDiagnostics;
@ -385,14 +385,6 @@ impl<'db> Type<'db> {
}
}
pub fn builtin_str_instance(db: &'db dyn Db) -> Self {
builtins_symbol_ty(db, "str").to_instance(db)
}
pub fn builtin_int_instance(db: &'db dyn Db) -> Self {
builtins_symbol_ty(db, "int").to_instance(db)
}
pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
match self {
Type::Class(class) => class.is_stdlib_symbol(db, module_name, name),
@ -423,19 +415,17 @@ impl<'db> Type<'db> {
(_, Type::Unknown | Type::Any | Type::Todo) => false,
(Type::Never, _) => true,
(_, Type::Never) => false,
(Type::IntLiteral(_), Type::Instance(class))
if class.is_stdlib_symbol(db, "builtins", "int") =>
{
(Type::IntLiteral(_), Type::Instance(class)) if class.is_known(db, KnownClass::Int) => {
true
}
(Type::StringLiteral(_), Type::LiteralString) => true,
(Type::StringLiteral(_) | Type::LiteralString, Type::Instance(class))
if class.is_stdlib_symbol(db, "builtins", "str") =>
if class.is_known(db, KnownClass::Str) =>
{
true
}
(Type::BytesLiteral(_), Type::Instance(class))
if class.is_stdlib_symbol(db, "builtins", "bytes") =>
if class.is_known(db, KnownClass::Bytes) =>
{
true
}
@ -443,8 +433,8 @@ impl<'db> Type<'db> {
.elements(db)
.iter()
.any(|&elem_ty| ty.is_subtype_of(db, elem_ty)),
(_, Type::Instance(class)) if class.is_stdlib_symbol(db, "builtins", "object") => true,
(Type::Instance(class), _) if class.is_stdlib_symbol(db, "builtins", "object") => false,
(_, Type::Instance(class)) if class.is_known(db, KnownClass::Object) => true,
(Type::Instance(class), _) if class.is_known(db, KnownClass::Object) => false,
// TODO
_ => false,
}
@ -600,9 +590,9 @@ impl<'db> Type<'db> {
fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> {
match self {
// TODO validate typed call arguments vs callable signature
Type::Function(function_type) => match function_type.kind(db) {
FunctionKind::Ordinary => CallOutcome::callable(function_type.return_type(db)),
FunctionKind::RevealType => CallOutcome::revealed(
Type::Function(function_type) => match function_type.known(db) {
None => CallOutcome::callable(function_type.return_type(db)),
Some(KnownFunction::RevealType) => CallOutcome::revealed(
function_type.return_type(db),
*arg_types.first().unwrap_or(&Type::Unknown),
),
@ -610,16 +600,15 @@ impl<'db> Type<'db> {
// TODO annotated return type on `__new__` or metaclass `__call__`
Type::Class(class) => {
// If the class is the builtin-bool class (for example `bool(1)`), we try to return
// the specific truthiness value of the input arg, `Literal[True]` for the example above.
let is_bool = class.is_stdlib_symbol(db, "builtins", "bool");
CallOutcome::callable(if is_bool {
arg_types
CallOutcome::callable(match class.known(db) {
// If the class is the builtin-bool class (for example `bool(1)`), we try to
// return the specific truthiness value of the input arg, `Literal[True]` for
// the example above.
Some(KnownClass::Bool) => arg_types
.first()
.map(|arg| arg.bool(db).into_type(db))
.unwrap_or(Type::BooleanLiteral(false))
} else {
Type::Instance(class)
.unwrap_or(Type::BooleanLiteral(false)),
_ => Type::Instance(class),
})
}
@ -714,7 +703,7 @@ impl<'db> Type<'db> {
let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");
dunder_get_item_method
.call(db, &[self, builtins_symbol_ty(db, "int").to_instance(db)])
.call(db, &[self, KnownClass::Int.to_instance(db)])
.return_ty(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
@ -758,17 +747,17 @@ impl<'db> Type<'db> {
Type::Never => Type::Never,
Type::Instance(class) => Type::Class(*class),
Type::Union(union) => union.map(db, |ty| ty.to_meta_type(db)),
Type::BooleanLiteral(_) => builtins_symbol_ty(db, "bool"),
Type::BytesLiteral(_) => builtins_symbol_ty(db, "bytes"),
Type::IntLiteral(_) => builtins_symbol_ty(db, "int"),
Type::Function(_) => types_symbol_ty(db, "FunctionType"),
Type::Module(_) => types_symbol_ty(db, "ModuleType"),
Type::Tuple(_) => builtins_symbol_ty(db, "tuple"),
Type::None => typeshed_symbol_ty(db, "NoneType"),
Type::BooleanLiteral(_) => KnownClass::Bool.to_class(db),
Type::BytesLiteral(_) => KnownClass::Bytes.to_class(db),
Type::IntLiteral(_) => KnownClass::Int.to_class(db),
Type::Function(_) => KnownClass::FunctionType.to_class(db),
Type::Module(_) => KnownClass::ModuleType.to_class(db),
Type::Tuple(_) => KnownClass::Tuple.to_class(db),
Type::None => KnownClass::NoneType.to_class(db),
// TODO not accurate if there's a custom metaclass...
Type::Class(_) => builtins_symbol_ty(db, "type"),
Type::Class(_) => KnownClass::Type.to_class(db),
// TODO can we do better here? `type[LiteralString]`?
Type::StringLiteral(_) | Type::LiteralString => builtins_symbol_ty(db, "str"),
Type::StringLiteral(_) | Type::LiteralString => KnownClass::Str.to_class(db),
// TODO: `type[Any]`?
Type::Any => Type::Todo,
// TODO: `type[Unknown]`?
@ -790,7 +779,7 @@ impl<'db> Type<'db> {
Type::IntLiteral(_) | Type::BooleanLiteral(_) => self.repr(db),
Type::StringLiteral(_) | Type::LiteralString => *self,
// TODO: handle more complex types
_ => Type::builtin_str_instance(db),
_ => KnownClass::Str.to_instance(db),
}
}
@ -813,7 +802,7 @@ impl<'db> Type<'db> {
})),
Type::LiteralString => Type::LiteralString,
// TODO: handle more complex types
_ => Type::builtin_str_instance(db),
_ => KnownClass::Str.to_instance(db),
}
}
}
@ -824,6 +813,133 @@ impl<'db> From<&Type<'db>> for Type<'db> {
}
}
/// Non-exhaustive enumeration of known classes (e.g. `builtins.int`, `typing.Any`, ...) to allow
/// for easier syntax when interacting with very common classes.
///
/// Feel free to expand this enum if you ever find yourself using the same class in multiple
/// places.
/// Note: good candidates are any classes in `[crate::stdlib::CoreStdlibModule]`
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KnownClass {
// To figure out where an stdlib symbol is defined, you can go into `crates/red_knot_vendored`
// and grep for the symbol name in any `.pyi` file.
// Builtins
Bool,
Object,
Bytes,
Type,
Int,
Float,
Str,
List,
Tuple,
Set,
Dict,
// Types
ModuleType,
FunctionType,
// Typeshed
NoneType, // Part of `types` for Python >= 3.10
}
impl<'db> KnownClass {
pub const fn as_str(&self) -> &'static str {
match self {
Self::Bool => "bool",
Self::Object => "object",
Self::Bytes => "bytes",
Self::Tuple => "tuple",
Self::Int => "int",
Self::Float => "float",
Self::Str => "str",
Self::Set => "set",
Self::Dict => "dict",
Self::List => "list",
Self::Type => "type",
Self::ModuleType => "ModuleType",
Self::FunctionType => "FunctionType",
Self::NoneType => "NoneType",
}
}
pub fn to_instance(&self, db: &'db dyn Db) -> Type<'db> {
self.to_class(db).to_instance(db)
}
pub fn to_class(&self, db: &'db dyn Db) -> Type<'db> {
match self {
Self::Bool
| Self::Object
| Self::Bytes
| Self::Type
| Self::Int
| Self::Float
| Self::Str
| Self::List
| Self::Tuple
| Self::Set
| Self::Dict => builtins_symbol_ty(db, self.as_str()),
Self::ModuleType | Self::FunctionType => types_symbol_ty(db, self.as_str()),
Self::NoneType => typeshed_symbol_ty(db, self.as_str()),
}
}
pub fn maybe_from_module(module: &Module, class_name: &str) -> Option<Self> {
let candidate = Self::from_name(class_name)?;
if candidate.check_module(module) {
Some(candidate)
} else {
None
}
}
fn from_name(name: &str) -> Option<Self> {
// Note: if this becomes hard to maintain (as rust can't ensure at compile time that all
// variants of `Self` are covered), we might use a macro (in-house or dependency)
// See: https://stackoverflow.com/q/39070244
match name {
"bool" => Some(Self::Bool),
"object" => Some(Self::Object),
"bytes" => Some(Self::Bytes),
"tuple" => Some(Self::Tuple),
"type" => Some(Self::Type),
"int" => Some(Self::Int),
"float" => Some(Self::Float),
"str" => Some(Self::Str),
"set" => Some(Self::Set),
"dict" => Some(Self::Dict),
"list" => Some(Self::List),
"NoneType" => Some(Self::NoneType),
"ModuleType" => Some(Self::ModuleType),
"FunctionType" => Some(Self::FunctionType),
_ => None,
}
}
/// Private method checking if known class can be defined in the given module.
fn check_module(self, module: &Module) -> bool {
if !module.search_path().is_standard_library() {
return false;
}
match self {
Self::Bool
| Self::Object
| Self::Bytes
| Self::Type
| Self::Int
| Self::Float
| Self::Str
| Self::List
| Self::Tuple
| Self::Set
| Self::Dict => module.name() == "builtins",
Self::ModuleType | Self::FunctionType => module.name() == "types",
Self::NoneType => matches!(module.name().as_str(), "_typeshed" | "types"),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum CallOutcome<'db> {
Callable {
@ -1128,7 +1244,7 @@ impl Truthiness {
match self {
Self::AlwaysTrue => Type::BooleanLiteral(true),
Self::AlwaysFalse => Type::BooleanLiteral(false),
Self::Ambiguous => builtins_symbol_ty(db, "bool").to_instance(db),
Self::Ambiguous => KnownClass::Bool.to_instance(db),
}
}
}
@ -1150,7 +1266,7 @@ pub struct FunctionType<'db> {
pub name: ast::name::Name,
/// Is this a function that we special-case somehow? If so, which one?
kind: FunctionKind,
known: Option<KnownFunction>,
definition: Definition<'db>,
@ -1202,11 +1318,10 @@ impl<'db> FunctionType<'db> {
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)]
pub enum FunctionKind {
/// Just a normal function for which we have no particular special casing
#[default]
Ordinary,
/// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might
/// have special behavior.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum KnownFunction {
/// `builtins.reveal_type`, `typing.reveal_type` or `typing_extensions.reveal_type`
RevealType,
}
@ -1220,9 +1335,18 @@ pub struct ClassType<'db> {
definition: Definition<'db>,
body_scope: ScopeId<'db>,
known: Option<KnownClass>,
}
impl<'db> ClassType<'db> {
pub fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool {
match self.known(db) {
Some(known) => known == known_class,
None => false,
}
}
/// Return true if this class is a standard library type with given module name and name.
pub(crate) fn is_stdlib_symbol(self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
name == self.name(db)

View file

@ -25,10 +25,12 @@
//! * No type in an intersection can be a supertype of any other type in the intersection (just
//! eliminate the supertype from the intersection).
//! * An intersection containing two non-overlapping types should simplify to [`Type::Never`].
use crate::types::{builtins_symbol_ty, IntersectionType, Type, UnionType};
use crate::types::{IntersectionType, Type, UnionType};
use crate::{Db, FxOrderSet};
use smallvec::SmallVec;
use super::KnownClass;
pub(crate) struct UnionBuilder<'db> {
elements: Vec<Type<'db>>,
db: &'db dyn Db,
@ -64,7 +66,7 @@ impl<'db> UnionBuilder<'db> {
let mut to_remove = SmallVec::<[usize; 2]>::new();
for (index, element) in self.elements.iter().enumerate() {
if Some(*element) == bool_pair {
to_add = builtins_symbol_ty(self.db, "bool");
to_add = KnownClass::Bool.to_class(self.db);
to_remove.push(index);
// The type we are adding is a BooleanLiteral, which doesn't have any
// subtypes. And we just found that the union already contained our
@ -300,7 +302,7 @@ mod tests {
use crate::db::tests::TestDb;
use crate::program::{Program, SearchPathSettings};
use crate::python_version::PythonVersion;
use crate::types::{builtins_symbol_ty, UnionBuilder};
use crate::types::{KnownClass, UnionBuilder};
use crate::ProgramSettings;
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
@ -360,7 +362,7 @@ mod tests {
#[test]
fn build_union_bool() {
let db = setup_db();
let bool_ty = builtins_symbol_ty(&db, "bool");
let bool_ty = KnownClass::Bool.to_class(&db);
let t0 = Type::BooleanLiteral(true);
let t1 = Type::BooleanLiteral(true);
@ -389,7 +391,7 @@ mod tests {
#[test]
fn build_union_simplify_subtype() {
let db = setup_db();
let t0 = Type::builtin_str_instance(&db);
let t0 = KnownClass::Str.to_instance(&db);
let t1 = Type::LiteralString;
let u0 = UnionType::from_elements(&db, [t0, t1]);
let u1 = UnionType::from_elements(&db, [t1, t0]);
@ -401,7 +403,7 @@ mod tests {
#[test]
fn build_union_no_simplify_unknown() {
let db = setup_db();
let t0 = Type::builtin_str_instance(&db);
let t0 = KnownClass::Str.to_instance(&db);
let t1 = Type::Unknown;
let u0 = UnionType::from_elements(&db, [t0, t1]);
let u1 = UnionType::from_elements(&db, [t1, t0]);
@ -413,9 +415,9 @@ mod tests {
#[test]
fn build_union_subsume_multiple() {
let db = setup_db();
let str_ty = Type::builtin_str_instance(&db);
let int_ty = Type::builtin_int_instance(&db);
let object_ty = builtins_symbol_ty(&db, "object").to_instance(&db);
let str_ty = KnownClass::Str.to_instance(&db);
let int_ty = KnownClass::Int.to_instance(&db);
let object_ty = KnownClass::Object.to_instance(&db);
let unknown_ty = Type::Unknown;
let u0 = UnionType::from_elements(&db, [str_ty, unknown_ty, int_ty, object_ty]);

View file

@ -51,11 +51,13 @@ use crate::stdlib::builtins_module_scope;
use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
use crate::types::{
bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty,
typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionKind, FunctionType,
typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionType, KnownFunction,
StringLiteralType, Truthiness, TupleType, Type, TypeArrayDisplay, UnionType,
};
use crate::Db;
use super::KnownClass;
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the
/// scope.
@ -518,8 +520,8 @@ impl<'db> TypeInferenceBuilder<'db> {
match left {
Type::IntLiteral(_) => {}
Type::Instance(cls)
if cls.is_stdlib_symbol(self.db, "builtins", "float")
|| cls.is_stdlib_symbol(self.db, "builtins", "int") => {}
if cls.is_known(self.db, KnownClass::Float)
|| cls.is_known(self.db, KnownClass::Int) => {}
_ => return,
};
@ -749,8 +751,10 @@ impl<'db> TypeInferenceBuilder<'db> {
}
let function_kind = match &**name {
"reveal_type" if definition.is_typing_definition(self.db) => FunctionKind::RevealType,
_ => FunctionKind::Ordinary,
"reveal_type" if definition.is_typing_definition(self.db) => {
Some(KnownFunction::RevealType)
}
_ => None,
};
let function_ty = Type::Function(FunctionType::new(
self.db,
@ -861,11 +865,15 @@ impl<'db> TypeInferenceBuilder<'db> {
.node_scope(NodeWithScopeRef::Class(class))
.to_scope_id(self.db, self.file);
let maybe_known_class = file_to_module(self.db, body_scope.file(self.db))
.as_ref()
.and_then(|module| KnownClass::maybe_from_module(module, name.as_str()));
let class_ty = Type::Class(ClassType::new(
self.db,
name.id.clone(),
definition,
body_scope,
maybe_known_class,
));
self.add_declaration_with_binding(class.into(), definition, class_ty, class_ty);
@ -1708,8 +1716,8 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::Number::Int(n) => n
.as_i64()
.map(Type::IntLiteral)
.unwrap_or_else(|| Type::builtin_int_instance(self.db)),
ast::Number::Float(_) => builtins_symbol_ty(self.db, "float").to_instance(self.db),
.unwrap_or_else(|| KnownClass::Int.to_instance(self.db)),
ast::Number::Float(_) => KnownClass::Float.to_instance(self.db),
ast::Number::Complex { .. } => {
builtins_symbol_ty(self.db, "complex").to_instance(self.db)
}
@ -1826,7 +1834,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// TODO generic
builtins_symbol_ty(self.db, "list").to_instance(self.db)
KnownClass::List.to_instance(self.db)
}
fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> {
@ -1837,7 +1845,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// TODO generic
builtins_symbol_ty(self.db, "set").to_instance(self.db)
KnownClass::Set.to_instance(self.db)
}
fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> {
@ -1849,7 +1857,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// TODO generic
builtins_symbol_ty(self.db, "dict").to_instance(self.db)
KnownClass::Dict.to_instance(self.db)
}
/// Infer the type of the `iter` expression of the first comprehension.
@ -2347,31 +2355,31 @@ impl<'db> TypeInferenceBuilder<'db> {
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Add) => n
.checked_add(m)
.map(Type::IntLiteral)
.unwrap_or_else(|| Type::builtin_int_instance(self.db)),
.unwrap_or_else(|| KnownClass::Int.to_instance(self.db)),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Sub) => n
.checked_sub(m)
.map(Type::IntLiteral)
.unwrap_or_else(|| Type::builtin_int_instance(self.db)),
.unwrap_or_else(|| KnownClass::Int.to_instance(self.db)),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mult) => n
.checked_mul(m)
.map(Type::IntLiteral)
.unwrap_or_else(|| Type::builtin_int_instance(self.db)),
.unwrap_or_else(|| KnownClass::Int.to_instance(self.db)),
(Type::IntLiteral(_), Type::IntLiteral(_), ast::Operator::Div) => {
builtins_symbol_ty(self.db, "float").to_instance(self.db)
KnownClass::Float.to_instance(self.db)
}
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::FloorDiv) => n
.checked_div(m)
.map(Type::IntLiteral)
.unwrap_or_else(|| Type::builtin_int_instance(self.db)),
.unwrap_or_else(|| KnownClass::Int.to_instance(self.db)),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mod) => n
.checked_rem(m)
.map(Type::IntLiteral)
.unwrap_or_else(|| Type::builtin_int_instance(self.db)),
.unwrap_or_else(|| KnownClass::Int.to_instance(self.db)),
(Type::BytesLiteral(lhs), Type::BytesLiteral(rhs), ast::Operator::Add) => {
Type::BytesLiteral(BytesLiteralType::new(
@ -2581,10 +2589,10 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::CmpOp::In | ast::CmpOp::NotIn => None,
},
(Type::IntLiteral(_), Type::Instance(_)) => {
self.infer_binary_type_comparison(Type::builtin_int_instance(self.db), op, right)
self.infer_binary_type_comparison(KnownClass::Int.to_instance(self.db), op, right)
}
(Type::Instance(_), Type::IntLiteral(_)) => {
self.infer_binary_type_comparison(left, op, Type::builtin_int_instance(self.db))
self.infer_binary_type_comparison(left, op, KnownClass::Int.to_instance(self.db))
}
// Booleans are coded as integers (False = 0, True = 1)
(Type::IntLiteral(n), Type::BooleanLiteral(b)) => self.infer_binary_type_comparison(
@ -3124,7 +3132,7 @@ impl StringPartsCollector {
fn ty(self, db: &dyn Db) -> Type {
if self.expression {
Type::builtin_str_instance(db)
KnownClass::Str.to_instance(db)
} else if let Some(concatenated) = self.concatenated {
Type::StringLiteral(StringLiteralType::new(db, concatenated.into_boxed_str()))
} else {