[red-knot] Add Type::bool and boolean expression inference (#13449)

This commit is contained in:
TomerBin 2024-09-25 03:02:26 +03:00 committed by GitHub
parent 03503f7f56
commit be1d5e3368
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 252 additions and 14 deletions

View file

@ -521,6 +521,54 @@ impl<'db> Type<'db> {
} }
} }
/// Resolves the boolean value of a type.
///
/// This is used to determine the value that would be returned
/// when `bool(x)` is called on an object `x`.
fn bool(&self, db: &'db dyn Db) -> Truthiness {
match self {
Type::Any | Type::Never | Type::Unknown | Type::Unbound => Truthiness::Ambiguous,
Type::None => Truthiness::AlwaysFalse,
Type::Function(_) | Type::RevealTypeFunction(_) => Truthiness::AlwaysTrue,
Type::Module(_) => Truthiness::AlwaysTrue,
Type::Class(_) => {
// TODO: lookup `__bool__` and `__len__` methods on the class's metaclass
// More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing
Truthiness::Ambiguous
}
Type::Instance(_) => {
// TODO: lookup `__bool__` and `__len__` methods on the instance's class
// More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing
Truthiness::Ambiguous
}
Type::Union(union) => {
let union_elements = union.elements(db);
let first_element_truthiness = union_elements[0].bool(db);
if first_element_truthiness.is_ambiguous() {
return Truthiness::Ambiguous;
}
if !union_elements
.iter()
.skip(1)
.all(|element| element.bool(db) == first_element_truthiness)
{
return Truthiness::Ambiguous;
}
first_element_truthiness
}
Type::Intersection(_) => {
// TODO
Truthiness::Ambiguous
}
Type::IntLiteral(num) => Truthiness::from(*num != 0),
Type::BooleanLiteral(bool) => Truthiness::from(*bool),
Type::StringLiteral(str) => Truthiness::from(!str.value(db).is_empty()),
Type::LiteralString => Truthiness::Ambiguous,
Type::BytesLiteral(bytes) => Truthiness::from(!bytes.value(db).is_empty()),
Type::Tuple(items) => Truthiness::from(!items.elements(db).is_empty()),
}
}
/// Return the type resulting from calling an object of this type. /// Return the type resulting from calling an object of this type.
/// ///
/// Returns `None` if `self` is not a callable type. /// Returns `None` if `self` is not a callable type.
@ -873,6 +921,50 @@ impl<'db> IterationOutcome<'db> {
} }
} }
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Truthiness {
/// For an object `x`, `bool(x)` will always return `True`
AlwaysTrue,
/// For an object `x`, `bool(x)` will always return `False`
AlwaysFalse,
/// For an object `x`, `bool(x)` could return either `True` or `False`
Ambiguous,
}
impl Truthiness {
const fn is_ambiguous(self) -> bool {
matches!(self, Truthiness::Ambiguous)
}
#[allow(unused)]
const fn negate(self) -> Self {
match self {
Self::AlwaysTrue => Self::AlwaysFalse,
Self::AlwaysFalse => Self::AlwaysTrue,
Self::Ambiguous => Self::Ambiguous,
}
}
#[allow(unused)]
fn into_type(self, db: &dyn Db) -> Type {
match self {
Self::AlwaysTrue => Type::BooleanLiteral(true),
Self::AlwaysFalse => Type::BooleanLiteral(false),
Self::Ambiguous => builtins_symbol_ty(db, "bool").to_instance(db),
}
}
}
impl From<bool> for Truthiness {
fn from(value: bool) -> Self {
if value {
Truthiness::AlwaysTrue
} else {
Truthiness::AlwaysFalse
}
}
}
#[salsa::interned] #[salsa::interned]
pub struct FunctionType<'db> { pub struct FunctionType<'db> {
/// name of the function at definition /// name of the function at definition
@ -1075,7 +1167,10 @@ pub struct TupleType<'db> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{builtins_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionType}; use super::{
builtins_symbol_ty, BytesLiteralType, StringLiteralType, Truthiness, TupleType, Type,
UnionType,
};
use crate::db::tests::TestDb; use crate::db::tests::TestDb;
use crate::program::{Program, SearchPathSettings}; use crate::program::{Program, SearchPathSettings};
use crate::python_version::PythonVersion; use crate::python_version::PythonVersion;
@ -1116,6 +1211,7 @@ mod tests {
BytesLiteral(&'static str), BytesLiteral(&'static str),
BuiltinInstance(&'static str), BuiltinInstance(&'static str),
Union(Vec<Ty>), Union(Vec<Ty>),
Tuple(Vec<Ty>),
} }
impl Ty { impl Ty {
@ -1136,6 +1232,10 @@ mod tests {
Ty::Union(tys) => { Ty::Union(tys) => {
UnionType::from_elements(db, tys.into_iter().map(|ty| ty.into_type(db))) UnionType::from_elements(db, tys.into_iter().map(|ty| ty.into_type(db)))
} }
Ty::Tuple(tys) => {
let elements = tys.into_iter().map(|ty| ty.into_type(db)).collect();
Type::Tuple(TupleType::new(db, elements))
}
} }
} }
} }
@ -1205,4 +1305,32 @@ mod tests {
assert!(from.into_type(&db).is_equivalent_to(&db, to.into_type(&db))); assert!(from.into_type(&db).is_equivalent_to(&db, to.into_type(&db)));
} }
#[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")]
#[test_case(Ty::IntLiteral(-1))]
#[test_case(Ty::StringLiteral("foo"))]
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(0)]))]
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]))]
fn is_truthy(ty: Ty) {
let db = setup_db();
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::AlwaysTrue);
}
#[test_case(Ty::Tuple(vec![]))]
#[test_case(Ty::IntLiteral(0))]
#[test_case(Ty::StringLiteral(""))]
#[test_case(Ty::Union(vec![Ty::IntLiteral(0), Ty::IntLiteral(0)]))]
fn is_falsy(ty: Ty) {
let db = setup_db();
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::AlwaysFalse);
}
#[test_case(Ty::BuiltinInstance("str"))]
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(0)]))]
#[test_case(Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::IntLiteral(0)]))]
#[test_case(Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::IntLiteral(1)]))]
fn boolean_value_is_unknown(ty: Ty) {
let db = setup_db();
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::Ambiguous);
}
} }

View file

@ -28,14 +28,13 @@
//! definitions once the rest of the types in the scope have been inferred. //! definitions once the rest of the types in the scope have been inferred.
use std::num::NonZeroU32; use std::num::NonZeroU32;
use rustc_hash::FxHashMap;
use salsa;
use salsa::plumbing::AsId;
use ruff_db::files::File; use ruff_db::files::File;
use ruff_db::parsed::parsed_module; use ruff_db::parsed::parsed_module;
use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp}; use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp};
use ruff_text_size::Ranged; use ruff_text_size::Ranged;
use rustc_hash::FxHashMap;
use salsa;
use salsa::plumbing::AsId;
use crate::module_name::ModuleName; use crate::module_name::ModuleName;
use crate::module_resolver::{file_to_module, resolve_module}; use crate::module_resolver::{file_to_module, resolve_module};
@ -52,7 +51,7 @@ use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
use crate::types::{ use crate::types::{
bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty, bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty,
typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionType, StringLiteralType, typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionType, StringLiteralType,
TupleType, Type, TypeArrayDisplay, UnionType, Truthiness, TupleType, Type, TypeArrayDisplay, UnionType,
}; };
use crate::Db; use crate::Db;
@ -2318,16 +2317,35 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_boolean_expression(&mut self, bool_op: &ast::ExprBoolOp) -> Type<'db> { fn infer_boolean_expression(&mut self, bool_op: &ast::ExprBoolOp) -> Type<'db> {
let ast::ExprBoolOp { let ast::ExprBoolOp {
range: _, range: _,
op: _, op,
values, values,
} = bool_op; } = bool_op;
let mut done = false;
for value in values { UnionType::from_elements(
self.infer_expression(value); self.db,
values.iter().enumerate().map(|(i, value)| {
// We need to infer the type of every expression (that's an invariant maintained by
// type inference), even if we can short-circuit boolean evaluation of some of
// those types.
let value_ty = self.infer_expression(value);
if done {
Type::Never
} else {
let is_last = i == values.len() - 1;
match (value_ty.bool(self.db), is_last, op) {
(Truthiness::Ambiguous, _, _) => value_ty,
(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
(Truthiness::AlwaysFalse, _, ast::BoolOp::And)
| (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => {
done = true;
value_ty
} }
(_, true, _) => value_ty,
// TODO resolve bool op }
Type::Unknown }
}),
)
} }
fn infer_compare_expression(&mut self, compare: &ast::ExprCompare) -> Type<'db> { fn infer_compare_expression(&mut self, compare: &ast::ExprCompare) -> Type<'db> {
@ -6048,4 +6066,96 @@ mod tests {
); );
Ok(()) Ok(())
} }
#[test]
fn boolean_or_expression() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
def foo() -> str:
pass
a = True or False
b = 'x' or 'y' or 'z'
c = '' or 'y' or 'z'
d = False or 'z'
e = False or True
f = False or False
g = foo() or False
h = foo() or True
",
)?;
assert_public_ty(&db, "/src/a.py", "a", "Literal[True]");
assert_public_ty(&db, "/src/a.py", "b", r#"Literal["x"]"#);
assert_public_ty(&db, "/src/a.py", "c", r#"Literal["y"]"#);
assert_public_ty(&db, "/src/a.py", "d", r#"Literal["z"]"#);
assert_public_ty(&db, "/src/a.py", "e", "Literal[True]");
assert_public_ty(&db, "/src/a.py", "f", "Literal[False]");
assert_public_ty(&db, "/src/a.py", "g", "str | Literal[False]");
assert_public_ty(&db, "/src/a.py", "h", "str | Literal[True]");
Ok(())
}
#[test]
fn boolean_and_expression() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
def foo() -> str:
pass
a = True and False
b = False and True
c = foo() and False
d = foo() and True
e = 'x' and 'y' and 'z'
f = 'x' and 'y' and ''
g = '' and 'y'
",
)?;
assert_public_ty(&db, "/src/a.py", "a", "Literal[False]");
assert_public_ty(&db, "/src/a.py", "b", "Literal[False]");
assert_public_ty(&db, "/src/a.py", "c", "str | Literal[False]");
assert_public_ty(&db, "/src/a.py", "d", "str | Literal[True]");
assert_public_ty(&db, "/src/a.py", "e", r#"Literal["z"]"#);
assert_public_ty(&db, "/src/a.py", "f", r#"Literal[""]"#);
assert_public_ty(&db, "/src/a.py", "g", r#"Literal[""]"#);
Ok(())
}
#[test]
fn boolean_complex_expression() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
r#"
def foo() -> str:
pass
a = "x" and "y" or "z"
b = "x" or "y" and "z"
c = "" and "y" or "z"
d = "" or "y" and "z"
e = "x" and "y" or ""
f = "x" or "y" and ""
"#,
)?;
assert_public_ty(&db, "/src/a.py", "a", r#"Literal["y"]"#);
assert_public_ty(&db, "/src/a.py", "b", r#"Literal["x"]"#);
assert_public_ty(&db, "/src/a.py", "c", r#"Literal["z"]"#);
assert_public_ty(&db, "/src/a.py", "d", r#"Literal["z"]"#);
assert_public_ty(&db, "/src/a.py", "e", r#"Literal["y"]"#);
assert_public_ty(&db, "/src/a.py", "f", r#"Literal["x"]"#);
Ok(())
}
} }