diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 54df8499de..d388fa8b2f 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -521,6 +521,54 @@ impl<'db> Type<'db> { } } + /// Resolves the boolean value of a type. + /// + /// This is used to determine the value that would be returned + /// when `bool(x)` is called on an object `x`. + fn bool(&self, db: &'db dyn Db) -> Truthiness { + match self { + Type::Any | Type::Never | Type::Unknown | Type::Unbound => Truthiness::Ambiguous, + Type::None => Truthiness::AlwaysFalse, + Type::Function(_) | Type::RevealTypeFunction(_) => Truthiness::AlwaysTrue, + Type::Module(_) => Truthiness::AlwaysTrue, + Type::Class(_) => { + // TODO: lookup `__bool__` and `__len__` methods on the class's metaclass + // More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing + Truthiness::Ambiguous + } + Type::Instance(_) => { + // TODO: lookup `__bool__` and `__len__` methods on the instance's class + // More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing + Truthiness::Ambiguous + } + Type::Union(union) => { + let union_elements = union.elements(db); + let first_element_truthiness = union_elements[0].bool(db); + if first_element_truthiness.is_ambiguous() { + return Truthiness::Ambiguous; + } + if !union_elements + .iter() + .skip(1) + .all(|element| element.bool(db) == first_element_truthiness) + { + return Truthiness::Ambiguous; + } + first_element_truthiness + } + Type::Intersection(_) => { + // TODO + Truthiness::Ambiguous + } + Type::IntLiteral(num) => Truthiness::from(*num != 0), + Type::BooleanLiteral(bool) => Truthiness::from(*bool), + Type::StringLiteral(str) => Truthiness::from(!str.value(db).is_empty()), + Type::LiteralString => Truthiness::Ambiguous, + Type::BytesLiteral(bytes) => Truthiness::from(!bytes.value(db).is_empty()), + Type::Tuple(items) => 
Truthiness::from(!items.elements(db).is_empty()), + } + } + /// Return the type resulting from calling an object of this type. /// /// Returns `None` if `self` is not a callable type. @@ -873,6 +921,50 @@ impl<'db> IterationOutcome<'db> { } } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum Truthiness { + /// For an object `x`, `bool(x)` will always return `True` + AlwaysTrue, + /// For an object `x`, `bool(x)` will always return `False` + AlwaysFalse, + /// For an object `x`, `bool(x)` could return either `True` or `False` + Ambiguous, +} + +impl Truthiness { + const fn is_ambiguous(self) -> bool { + matches!(self, Truthiness::Ambiguous) + } + + #[allow(unused)] + const fn negate(self) -> Self { + match self { + Self::AlwaysTrue => Self::AlwaysFalse, + Self::AlwaysFalse => Self::AlwaysTrue, + Self::Ambiguous => Self::Ambiguous, + } + } + + #[allow(unused)] + fn into_type(self, db: &dyn Db) -> Type { + match self { + Self::AlwaysTrue => Type::BooleanLiteral(true), + Self::AlwaysFalse => Type::BooleanLiteral(false), + Self::Ambiguous => builtins_symbol_ty(db, "bool").to_instance(db), + } + } +} + +impl From<bool> for Truthiness { + fn from(value: bool) -> Self { + if value { + Truthiness::AlwaysTrue + } else { + Truthiness::AlwaysFalse + } + } +} + #[salsa::interned] pub struct FunctionType<'db> { /// name of the function at definition @@ -1075,7 +1167,10 @@ pub struct TupleType<'db> { #[cfg(test)] mod tests { - use super::{builtins_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionType}; + use super::{ + builtins_symbol_ty, BytesLiteralType, StringLiteralType, Truthiness, TupleType, Type, + UnionType, + }; use crate::db::tests::TestDb; use crate::program::{Program, SearchPathSettings}; use crate::python_version::PythonVersion; @@ -1116,6 +1211,7 @@ mod tests { BytesLiteral(&'static str), BuiltinInstance(&'static str), Union(Vec<Ty>), + Tuple(Vec<Ty>), } impl Ty { @@ -1136,6 +1232,10 @@ mod tests { Ty::Union(tys) => { UnionType::from_elements(db, 
tys.into_iter().map(|ty| ty.into_type(db))) } + Ty::Tuple(tys) => { + let elements = tys.into_iter().map(|ty| ty.into_type(db)).collect(); + Type::Tuple(TupleType::new(db, elements)) + } } } } @@ -1205,4 +1305,32 @@ mod tests { assert!(from.into_type(&db).is_equivalent_to(&db, to.into_type(&db))); } + + #[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")] + #[test_case(Ty::IntLiteral(-1))] + #[test_case(Ty::StringLiteral("foo"))] + #[test_case(Ty::Tuple(vec![Ty::IntLiteral(0)]))] + #[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]))] + fn is_truthy(ty: Ty) { + let db = setup_db(); + assert_eq!(ty.into_type(&db).bool(&db), Truthiness::AlwaysTrue); + } + + #[test_case(Ty::Tuple(vec![]))] + #[test_case(Ty::IntLiteral(0))] + #[test_case(Ty::StringLiteral(""))] + #[test_case(Ty::Union(vec![Ty::IntLiteral(0), Ty::IntLiteral(0)]))] + fn is_falsy(ty: Ty) { + let db = setup_db(); + assert_eq!(ty.into_type(&db).bool(&db), Truthiness::AlwaysFalse); + } + + #[test_case(Ty::BuiltinInstance("str"))] + #[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(0)]))] + #[test_case(Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::IntLiteral(0)]))] + #[test_case(Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::IntLiteral(1)]))] + fn boolean_value_is_unknown(ty: Ty) { + let db = setup_db(); + assert_eq!(ty.into_type(&db).bool(&db), Truthiness::Ambiguous); + } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index de3384e4e5..d13fccefaf 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -28,14 +28,13 @@ //! definitions once the rest of the types in the scope have been inferred. 
use std::num::NonZeroU32; -use rustc_hash::FxHashMap; -use salsa; -use salsa::plumbing::AsId; - use ruff_db::files::File; use ruff_db::parsed::parsed_module; use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp}; use ruff_text_size::Ranged; +use rustc_hash::FxHashMap; +use salsa; +use salsa::plumbing::AsId; use crate::module_name::ModuleName; use crate::module_resolver::{file_to_module, resolve_module}; @@ -52,7 +51,7 @@ use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics}; use crate::types::{ bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty, typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionType, StringLiteralType, - TupleType, Type, TypeArrayDisplay, UnionType, + Truthiness, TupleType, Type, TypeArrayDisplay, UnionType, }; use crate::Db; @@ -2318,16 +2317,35 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_boolean_expression(&mut self, bool_op: &ast::ExprBoolOp) -> Type<'db> { let ast::ExprBoolOp { range: _, - op: _, + op, values, } = bool_op; - - for value in values { - self.infer_expression(value); - } - - // TODO resolve bool op - Type::Unknown + let mut done = false; + UnionType::from_elements( + self.db, + values.iter().enumerate().map(|(i, value)| { + // We need to infer the type of every expression (that's an invariant maintained by + // type inference), even if we can short-circuit boolean evaluation of some of + // those types. 
+ let value_ty = self.infer_expression(value); + if done { + Type::Never + } else { + let is_last = i == values.len() - 1; + match (value_ty.bool(self.db), is_last, op) { + (Truthiness::Ambiguous, _, _) => value_ty, + (Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never, + (Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never, + (Truthiness::AlwaysFalse, _, ast::BoolOp::And) + | (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => { + done = true; + value_ty + } + (_, true, _) => value_ty, + } + } + }), + ) } fn infer_compare_expression(&mut self, compare: &ast::ExprCompare) -> Type<'db> { @@ -6048,4 +6066,96 @@ mod tests { ); Ok(()) } + + #[test] + fn boolean_or_expression() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + def foo() -> str: + pass + + a = True or False + b = 'x' or 'y' or 'z' + c = '' or 'y' or 'z' + d = False or 'z' + e = False or True + f = False or False + g = foo() or False + h = foo() or True + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "Literal[True]"); + assert_public_ty(&db, "/src/a.py", "b", r#"Literal["x"]"#); + assert_public_ty(&db, "/src/a.py", "c", r#"Literal["y"]"#); + assert_public_ty(&db, "/src/a.py", "d", r#"Literal["z"]"#); + assert_public_ty(&db, "/src/a.py", "e", "Literal[True]"); + assert_public_ty(&db, "/src/a.py", "f", "Literal[False]"); + assert_public_ty(&db, "/src/a.py", "g", "str | Literal[False]"); + assert_public_ty(&db, "/src/a.py", "h", "str | Literal[True]"); + + Ok(()) + } + + #[test] + fn boolean_and_expression() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + def foo() -> str: + pass + + a = True and False + b = False and True + c = foo() and False + d = foo() and True + e = 'x' and 'y' and 'z' + f = 'x' and 'y' and '' + g = '' and 'y' + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "Literal[False]"); + assert_public_ty(&db, "/src/a.py", "b", "Literal[False]"); + assert_public_ty(&db, 
"/src/a.py", "c", "str | Literal[False]"); + assert_public_ty(&db, "/src/a.py", "d", "str | Literal[True]"); + assert_public_ty(&db, "/src/a.py", "e", r#"Literal["z"]"#); + assert_public_ty(&db, "/src/a.py", "f", r#"Literal[""]"#); + assert_public_ty(&db, "/src/a.py", "g", r#"Literal[""]"#); + Ok(()) + } + + #[test] + fn boolean_complex_expression() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + r#" + def foo() -> str: + pass + + a = "x" and "y" or "z" + b = "x" or "y" and "z" + c = "" and "y" or "z" + d = "" or "y" and "z" + e = "x" and "y" or "" + f = "x" or "y" and "" + + "#, + )?; + + assert_public_ty(&db, "/src/a.py", "a", r#"Literal["y"]"#); + assert_public_ty(&db, "/src/a.py", "b", r#"Literal["x"]"#); + assert_public_ty(&db, "/src/a.py", "c", r#"Literal["z"]"#); + assert_public_ty(&db, "/src/a.py", "d", r#"Literal["z"]"#); + assert_public_ty(&db, "/src/a.py", "e", r#"Literal["y"]"#); + assert_public_ty(&db, "/src/a.py", "f", r#"Literal["x"]"#); + Ok(()) + } }