mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-30 05:45:24 +00:00
[red-knot] Add Type::bool
and boolean expression inference (#13449)
This commit is contained in:
parent
03503f7f56
commit
be1d5e3368
2 changed files with 252 additions and 14 deletions
|
@ -521,6 +521,54 @@ impl<'db> Type<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Resolves the boolean value of a type.
|
||||
///
|
||||
/// This is used to determine the value that would be returned
|
||||
/// when `bool(x)` is called on an object `x`.
|
||||
fn bool(&self, db: &'db dyn Db) -> Truthiness {
|
||||
match self {
|
||||
Type::Any | Type::Never | Type::Unknown | Type::Unbound => Truthiness::Ambiguous,
|
||||
Type::None => Truthiness::AlwaysFalse,
|
||||
Type::Function(_) | Type::RevealTypeFunction(_) => Truthiness::AlwaysTrue,
|
||||
Type::Module(_) => Truthiness::AlwaysTrue,
|
||||
Type::Class(_) => {
|
||||
// TODO: lookup `__bool__` and `__len__` methods on the class's metaclass
|
||||
// More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing
|
||||
Truthiness::Ambiguous
|
||||
}
|
||||
Type::Instance(_) => {
|
||||
// TODO: lookup `__bool__` and `__len__` methods on the instance's class
|
||||
// More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing
|
||||
Truthiness::Ambiguous
|
||||
}
|
||||
Type::Union(union) => {
|
||||
let union_elements = union.elements(db);
|
||||
let first_element_truthiness = union_elements[0].bool(db);
|
||||
if first_element_truthiness.is_ambiguous() {
|
||||
return Truthiness::Ambiguous;
|
||||
}
|
||||
if !union_elements
|
||||
.iter()
|
||||
.skip(1)
|
||||
.all(|element| element.bool(db) == first_element_truthiness)
|
||||
{
|
||||
return Truthiness::Ambiguous;
|
||||
}
|
||||
first_element_truthiness
|
||||
}
|
||||
Type::Intersection(_) => {
|
||||
// TODO
|
||||
Truthiness::Ambiguous
|
||||
}
|
||||
Type::IntLiteral(num) => Truthiness::from(*num != 0),
|
||||
Type::BooleanLiteral(bool) => Truthiness::from(*bool),
|
||||
Type::StringLiteral(str) => Truthiness::from(!str.value(db).is_empty()),
|
||||
Type::LiteralString => Truthiness::Ambiguous,
|
||||
Type::BytesLiteral(bytes) => Truthiness::from(!bytes.value(db).is_empty()),
|
||||
Type::Tuple(items) => Truthiness::from(!items.elements(db).is_empty()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the type resulting from calling an object of this type.
|
||||
///
|
||||
/// Returns `None` if `self` is not a callable type.
|
||||
|
@ -873,6 +921,50 @@ impl<'db> IterationOutcome<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
enum Truthiness {
|
||||
/// For an object `x`, `bool(x)` will always return `True`
|
||||
AlwaysTrue,
|
||||
/// For an object `x`, `bool(x)` will always return `False`
|
||||
AlwaysFalse,
|
||||
/// For an object `x`, `bool(x)` could return either `True` or `False`
|
||||
Ambiguous,
|
||||
}
|
||||
|
||||
impl Truthiness {
|
||||
const fn is_ambiguous(self) -> bool {
|
||||
matches!(self, Truthiness::Ambiguous)
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
const fn negate(self) -> Self {
|
||||
match self {
|
||||
Self::AlwaysTrue => Self::AlwaysFalse,
|
||||
Self::AlwaysFalse => Self::AlwaysTrue,
|
||||
Self::Ambiguous => Self::Ambiguous,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn into_type(self, db: &dyn Db) -> Type {
|
||||
match self {
|
||||
Self::AlwaysTrue => Type::BooleanLiteral(true),
|
||||
Self::AlwaysFalse => Type::BooleanLiteral(false),
|
||||
Self::Ambiguous => builtins_symbol_ty(db, "bool").to_instance(db),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<bool> for Truthiness {
|
||||
fn from(value: bool) -> Self {
|
||||
if value {
|
||||
Truthiness::AlwaysTrue
|
||||
} else {
|
||||
Truthiness::AlwaysFalse
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[salsa::interned]
|
||||
pub struct FunctionType<'db> {
|
||||
/// name of the function at definition
|
||||
|
@ -1075,7 +1167,10 @@ pub struct TupleType<'db> {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{builtins_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionType};
|
||||
use super::{
|
||||
builtins_symbol_ty, BytesLiteralType, StringLiteralType, Truthiness, TupleType, Type,
|
||||
UnionType,
|
||||
};
|
||||
use crate::db::tests::TestDb;
|
||||
use crate::program::{Program, SearchPathSettings};
|
||||
use crate::python_version::PythonVersion;
|
||||
|
@ -1116,6 +1211,7 @@ mod tests {
|
|||
BytesLiteral(&'static str),
|
||||
BuiltinInstance(&'static str),
|
||||
Union(Vec<Ty>),
|
||||
Tuple(Vec<Ty>),
|
||||
}
|
||||
|
||||
impl Ty {
|
||||
|
@ -1136,6 +1232,10 @@ mod tests {
|
|||
Ty::Union(tys) => {
|
||||
UnionType::from_elements(db, tys.into_iter().map(|ty| ty.into_type(db)))
|
||||
}
|
||||
Ty::Tuple(tys) => {
|
||||
let elements = tys.into_iter().map(|ty| ty.into_type(db)).collect();
|
||||
Type::Tuple(TupleType::new(db, elements))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1205,4 +1305,32 @@ mod tests {
|
|||
|
||||
assert!(from.into_type(&db).is_equivalent_to(&db, to.into_type(&db)));
|
||||
}
|
||||
|
||||
#[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")]
|
||||
#[test_case(Ty::IntLiteral(-1))]
|
||||
#[test_case(Ty::StringLiteral("foo"))]
|
||||
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(0)]))]
|
||||
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]))]
|
||||
fn is_truthy(ty: Ty) {
|
||||
let db = setup_db();
|
||||
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::AlwaysTrue);
|
||||
}
|
||||
|
||||
#[test_case(Ty::Tuple(vec![]))]
|
||||
#[test_case(Ty::IntLiteral(0))]
|
||||
#[test_case(Ty::StringLiteral(""))]
|
||||
#[test_case(Ty::Union(vec![Ty::IntLiteral(0), Ty::IntLiteral(0)]))]
|
||||
fn is_falsy(ty: Ty) {
|
||||
let db = setup_db();
|
||||
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::AlwaysFalse);
|
||||
}
|
||||
|
||||
#[test_case(Ty::BuiltinInstance("str"))]
|
||||
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(0)]))]
|
||||
#[test_case(Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::IntLiteral(0)]))]
|
||||
#[test_case(Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::IntLiteral(1)]))]
|
||||
fn boolean_value_is_unknown(ty: Ty) {
|
||||
let db = setup_db();
|
||||
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::Ambiguous);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,14 +28,13 @@
|
|||
//! definitions once the rest of the types in the scope have been inferred.
|
||||
use std::num::NonZeroU32;
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
use salsa;
|
||||
use salsa::plumbing::AsId;
|
||||
|
||||
use ruff_db::files::File;
|
||||
use ruff_db::parsed::parsed_module;
|
||||
use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp};
|
||||
use ruff_text_size::Ranged;
|
||||
use rustc_hash::FxHashMap;
|
||||
use salsa;
|
||||
use salsa::plumbing::AsId;
|
||||
|
||||
use crate::module_name::ModuleName;
|
||||
use crate::module_resolver::{file_to_module, resolve_module};
|
||||
|
@ -52,7 +51,7 @@ use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
|
|||
use crate::types::{
|
||||
bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty,
|
||||
typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionType, StringLiteralType,
|
||||
TupleType, Type, TypeArrayDisplay, UnionType,
|
||||
Truthiness, TupleType, Type, TypeArrayDisplay, UnionType,
|
||||
};
|
||||
use crate::Db;
|
||||
|
||||
|
@ -2318,16 +2317,35 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
fn infer_boolean_expression(&mut self, bool_op: &ast::ExprBoolOp) -> Type<'db> {
|
||||
let ast::ExprBoolOp {
|
||||
range: _,
|
||||
op: _,
|
||||
op,
|
||||
values,
|
||||
} = bool_op;
|
||||
|
||||
for value in values {
|
||||
self.infer_expression(value);
|
||||
}
|
||||
|
||||
// TODO resolve bool op
|
||||
Type::Unknown
|
||||
let mut done = false;
|
||||
UnionType::from_elements(
|
||||
self.db,
|
||||
values.iter().enumerate().map(|(i, value)| {
|
||||
// We need to infer the type of every expression (that's an invariant maintained by
|
||||
// type inference), even if we can short-circuit boolean evaluation of some of
|
||||
// those types.
|
||||
let value_ty = self.infer_expression(value);
|
||||
if done {
|
||||
Type::Never
|
||||
} else {
|
||||
let is_last = i == values.len() - 1;
|
||||
match (value_ty.bool(self.db), is_last, op) {
|
||||
(Truthiness::Ambiguous, _, _) => value_ty,
|
||||
(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
|
||||
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
|
||||
(Truthiness::AlwaysFalse, _, ast::BoolOp::And)
|
||||
| (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => {
|
||||
done = true;
|
||||
value_ty
|
||||
}
|
||||
(_, true, _) => value_ty,
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn infer_compare_expression(&mut self, compare: &ast::ExprCompare) -> Type<'db> {
|
||||
|
@ -6048,4 +6066,96 @@ mod tests {
|
|||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn boolean_or_expression() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"/src/a.py",
|
||||
"
|
||||
def foo() -> str:
|
||||
pass
|
||||
|
||||
a = True or False
|
||||
b = 'x' or 'y' or 'z'
|
||||
c = '' or 'y' or 'z'
|
||||
d = False or 'z'
|
||||
e = False or True
|
||||
f = False or False
|
||||
g = foo() or False
|
||||
h = foo() or True
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_public_ty(&db, "/src/a.py", "a", "Literal[True]");
|
||||
assert_public_ty(&db, "/src/a.py", "b", r#"Literal["x"]"#);
|
||||
assert_public_ty(&db, "/src/a.py", "c", r#"Literal["y"]"#);
|
||||
assert_public_ty(&db, "/src/a.py", "d", r#"Literal["z"]"#);
|
||||
assert_public_ty(&db, "/src/a.py", "e", "Literal[True]");
|
||||
assert_public_ty(&db, "/src/a.py", "f", "Literal[False]");
|
||||
assert_public_ty(&db, "/src/a.py", "g", "str | Literal[False]");
|
||||
assert_public_ty(&db, "/src/a.py", "h", "str | Literal[True]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn boolean_and_expression() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"/src/a.py",
|
||||
"
|
||||
def foo() -> str:
|
||||
pass
|
||||
|
||||
a = True and False
|
||||
b = False and True
|
||||
c = foo() and False
|
||||
d = foo() and True
|
||||
e = 'x' and 'y' and 'z'
|
||||
f = 'x' and 'y' and ''
|
||||
g = '' and 'y'
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_public_ty(&db, "/src/a.py", "a", "Literal[False]");
|
||||
assert_public_ty(&db, "/src/a.py", "b", "Literal[False]");
|
||||
assert_public_ty(&db, "/src/a.py", "c", "str | Literal[False]");
|
||||
assert_public_ty(&db, "/src/a.py", "d", "str | Literal[True]");
|
||||
assert_public_ty(&db, "/src/a.py", "e", r#"Literal["z"]"#);
|
||||
assert_public_ty(&db, "/src/a.py", "f", r#"Literal[""]"#);
|
||||
assert_public_ty(&db, "/src/a.py", "g", r#"Literal[""]"#);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn boolean_complex_expression() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"/src/a.py",
|
||||
r#"
|
||||
def foo() -> str:
|
||||
pass
|
||||
|
||||
a = "x" and "y" or "z"
|
||||
b = "x" or "y" and "z"
|
||||
c = "" and "y" or "z"
|
||||
d = "" or "y" and "z"
|
||||
e = "x" and "y" or ""
|
||||
f = "x" or "y" and ""
|
||||
|
||||
"#,
|
||||
)?;
|
||||
|
||||
assert_public_ty(&db, "/src/a.py", "a", r#"Literal["y"]"#);
|
||||
assert_public_ty(&db, "/src/a.py", "b", r#"Literal["x"]"#);
|
||||
assert_public_ty(&db, "/src/a.py", "c", r#"Literal["z"]"#);
|
||||
assert_public_ty(&db, "/src/a.py", "d", r#"Literal["z"]"#);
|
||||
assert_public_ty(&db, "/src/a.py", "e", r#"Literal["y"]"#);
|
||||
assert_public_ty(&db, "/src/a.py", "f", r#"Literal["x"]"#);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue