[red-knot] Add Type::bool and boolean expression inference (#13449)

This commit is contained in:
TomerBin 2024-09-25 03:02:26 +03:00 committed by GitHub
parent 03503f7f56
commit be1d5e3368
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 252 additions and 14 deletions

View file

@ -521,6 +521,54 @@ impl<'db> Type<'db> {
}
}
/// Resolves the boolean value of a type.
///
/// This is used to determine the value that would be returned
/// when `bool(x)` is called on an object `x`.
fn bool(&self, db: &'db dyn Db) -> Truthiness {
match self {
Type::Any | Type::Never | Type::Unknown | Type::Unbound => Truthiness::Ambiguous,
Type::None => Truthiness::AlwaysFalse,
Type::Function(_) | Type::RevealTypeFunction(_) => Truthiness::AlwaysTrue,
Type::Module(_) => Truthiness::AlwaysTrue,
Type::Class(_) => {
// TODO: lookup `__bool__` and `__len__` methods on the class's metaclass
// More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing
Truthiness::Ambiguous
}
Type::Instance(_) => {
// TODO: lookup `__bool__` and `__len__` methods on the instance's class
// More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing
Truthiness::Ambiguous
}
Type::Union(union) => {
let union_elements = union.elements(db);
let first_element_truthiness = union_elements[0].bool(db);
if first_element_truthiness.is_ambiguous() {
return Truthiness::Ambiguous;
}
if !union_elements
.iter()
.skip(1)
.all(|element| element.bool(db) == first_element_truthiness)
{
return Truthiness::Ambiguous;
}
first_element_truthiness
}
Type::Intersection(_) => {
// TODO
Truthiness::Ambiguous
}
Type::IntLiteral(num) => Truthiness::from(*num != 0),
Type::BooleanLiteral(bool) => Truthiness::from(*bool),
Type::StringLiteral(str) => Truthiness::from(!str.value(db).is_empty()),
Type::LiteralString => Truthiness::Ambiguous,
Type::BytesLiteral(bytes) => Truthiness::from(!bytes.value(db).is_empty()),
Type::Tuple(items) => Truthiness::from(!items.elements(db).is_empty()),
}
}
/// Return the type resulting from calling an object of this type.
///
/// Returns `None` if `self` is not a callable type.
@ -873,6 +921,50 @@ impl<'db> IterationOutcome<'db> {
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Truthiness {
/// For an object `x`, `bool(x)` will always return `True`
AlwaysTrue,
/// For an object `x`, `bool(x)` will always return `False`
AlwaysFalse,
/// For an object `x`, `bool(x)` could return either `True` or `False`
Ambiguous,
}
impl Truthiness {
const fn is_ambiguous(self) -> bool {
matches!(self, Truthiness::Ambiguous)
}
#[allow(unused)]
const fn negate(self) -> Self {
match self {
Self::AlwaysTrue => Self::AlwaysFalse,
Self::AlwaysFalse => Self::AlwaysTrue,
Self::Ambiguous => Self::Ambiguous,
}
}
#[allow(unused)]
fn into_type(self, db: &dyn Db) -> Type {
match self {
Self::AlwaysTrue => Type::BooleanLiteral(true),
Self::AlwaysFalse => Type::BooleanLiteral(false),
Self::Ambiguous => builtins_symbol_ty(db, "bool").to_instance(db),
}
}
}
impl From<bool> for Truthiness {
fn from(value: bool) -> Self {
if value {
Truthiness::AlwaysTrue
} else {
Truthiness::AlwaysFalse
}
}
}
#[salsa::interned]
pub struct FunctionType<'db> {
/// name of the function at definition
@ -1075,7 +1167,10 @@ pub struct TupleType<'db> {
#[cfg(test)]
mod tests {
use super::{builtins_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionType};
use super::{
builtins_symbol_ty, BytesLiteralType, StringLiteralType, Truthiness, TupleType, Type,
UnionType,
};
use crate::db::tests::TestDb;
use crate::program::{Program, SearchPathSettings};
use crate::python_version::PythonVersion;
@ -1116,6 +1211,7 @@ mod tests {
BytesLiteral(&'static str),
BuiltinInstance(&'static str),
Union(Vec<Ty>),
Tuple(Vec<Ty>),
}
impl Ty {
@ -1136,6 +1232,10 @@ mod tests {
Ty::Union(tys) => {
UnionType::from_elements(db, tys.into_iter().map(|ty| ty.into_type(db)))
}
Ty::Tuple(tys) => {
let elements = tys.into_iter().map(|ty| ty.into_type(db)).collect();
Type::Tuple(TupleType::new(db, elements))
}
}
}
}
@ -1205,4 +1305,32 @@ mod tests {
assert!(from.into_type(&db).is_equivalent_to(&db, to.into_type(&db)));
}
#[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")]
#[test_case(Ty::IntLiteral(-1))]
#[test_case(Ty::StringLiteral("foo"))]
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(0)]))]
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]))]
fn is_truthy(ty: Ty) {
let db = setup_db();
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::AlwaysTrue);
}
#[test_case(Ty::Tuple(vec![]))]
#[test_case(Ty::IntLiteral(0))]
#[test_case(Ty::StringLiteral(""))]
#[test_case(Ty::Union(vec![Ty::IntLiteral(0), Ty::IntLiteral(0)]))]
fn is_falsy(ty: Ty) {
let db = setup_db();
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::AlwaysFalse);
}
#[test_case(Ty::BuiltinInstance("str"))]
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(0)]))]
#[test_case(Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::IntLiteral(0)]))]
#[test_case(Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::IntLiteral(1)]))]
fn boolean_value_is_unknown(ty: Ty) {
let db = setup_db();
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::Ambiguous);
}
}

View file

@ -28,14 +28,13 @@
//! definitions once the rest of the types in the scope have been inferred.
use std::num::NonZeroU32;
use rustc_hash::FxHashMap;
use salsa;
use salsa::plumbing::AsId;
use ruff_db::files::File;
use ruff_db::parsed::parsed_module;
use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp};
use ruff_text_size::Ranged;
use rustc_hash::FxHashMap;
use salsa;
use salsa::plumbing::AsId;
use crate::module_name::ModuleName;
use crate::module_resolver::{file_to_module, resolve_module};
@ -52,7 +51,7 @@ use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
use crate::types::{
bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty,
typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionType, StringLiteralType,
TupleType, Type, TypeArrayDisplay, UnionType,
Truthiness, TupleType, Type, TypeArrayDisplay, UnionType,
};
use crate::Db;
@ -2318,16 +2317,35 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_boolean_expression(&mut self, bool_op: &ast::ExprBoolOp) -> Type<'db> {
let ast::ExprBoolOp {
range: _,
op: _,
op,
values,
} = bool_op;
for value in values {
self.infer_expression(value);
let mut done = false;
UnionType::from_elements(
self.db,
values.iter().enumerate().map(|(i, value)| {
// We need to infer the type of every expression (that's an invariant maintained by
// type inference), even if we can short-circuit boolean evaluation of some of
// those types.
let value_ty = self.infer_expression(value);
if done {
Type::Never
} else {
let is_last = i == values.len() - 1;
match (value_ty.bool(self.db), is_last, op) {
(Truthiness::Ambiguous, _, _) => value_ty,
(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
(Truthiness::AlwaysFalse, _, ast::BoolOp::And)
| (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => {
done = true;
value_ty
}
// TODO resolve bool op
Type::Unknown
(_, true, _) => value_ty,
}
}
}),
)
}
fn infer_compare_expression(&mut self, compare: &ast::ExprCompare) -> Type<'db> {
@ -6048,4 +6066,96 @@ mod tests {
);
Ok(())
}
#[test]
fn boolean_or_expression() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
def foo() -> str:
pass
a = True or False
b = 'x' or 'y' or 'z'
c = '' or 'y' or 'z'
d = False or 'z'
e = False or True
f = False or False
g = foo() or False
h = foo() or True
",
)?;
assert_public_ty(&db, "/src/a.py", "a", "Literal[True]");
assert_public_ty(&db, "/src/a.py", "b", r#"Literal["x"]"#);
assert_public_ty(&db, "/src/a.py", "c", r#"Literal["y"]"#);
assert_public_ty(&db, "/src/a.py", "d", r#"Literal["z"]"#);
assert_public_ty(&db, "/src/a.py", "e", "Literal[True]");
assert_public_ty(&db, "/src/a.py", "f", "Literal[False]");
assert_public_ty(&db, "/src/a.py", "g", "str | Literal[False]");
assert_public_ty(&db, "/src/a.py", "h", "str | Literal[True]");
Ok(())
}
#[test]
fn boolean_and_expression() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
def foo() -> str:
pass
a = True and False
b = False and True
c = foo() and False
d = foo() and True
e = 'x' and 'y' and 'z'
f = 'x' and 'y' and ''
g = '' and 'y'
",
)?;
assert_public_ty(&db, "/src/a.py", "a", "Literal[False]");
assert_public_ty(&db, "/src/a.py", "b", "Literal[False]");
assert_public_ty(&db, "/src/a.py", "c", "str | Literal[False]");
assert_public_ty(&db, "/src/a.py", "d", "str | Literal[True]");
assert_public_ty(&db, "/src/a.py", "e", r#"Literal["z"]"#);
assert_public_ty(&db, "/src/a.py", "f", r#"Literal[""]"#);
assert_public_ty(&db, "/src/a.py", "g", r#"Literal[""]"#);
Ok(())
}
#[test]
fn boolean_complex_expression() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
r#"
def foo() -> str:
pass
a = "x" and "y" or "z"
b = "x" or "y" and "z"
c = "" and "y" or "z"
d = "" or "y" and "z"
e = "x" and "y" or ""
f = "x" or "y" and ""
"#,
)?;
assert_public_ty(&db, "/src/a.py", "a", r#"Literal["y"]"#);
assert_public_ty(&db, "/src/a.py", "b", r#"Literal["x"]"#);
assert_public_ty(&db, "/src/a.py", "c", r#"Literal["z"]"#);
assert_public_ty(&db, "/src/a.py", "d", r#"Literal["z"]"#);
assert_public_ty(&db, "/src/a.py", "e", r#"Literal["y"]"#);
assert_public_ty(&db, "/src/a.py", "f", r#"Literal["x"]"#);
Ok(())
}
}