From ec72e675d9325d404864c0848e969f59a89090a7 Mon Sep 17 00:00:00 2001 From: TomerBin Date: Fri, 27 Sep 2024 22:11:55 +0300 Subject: [PATCH] Red Knot - Infer the return value of bool() (#13538) ## Summary Following #13449, this PR adds custom handling for the bool constructor, so when the input type has statically known truthiness value, it will be used as the return value of the bool function. For example, in the following snippet x will now be resolved to `Literal[True]` instead of `bool`. ```python x = bool(1) ``` ## Test Plan Some cargo tests were added. --- crates/red_knot_python_semantic/src/types.rs | 15 ++- .../src/types/infer.rs | 98 +++++++++++++++++-- 2 files changed, 106 insertions(+), 7 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 92921eaa11..21278a4232 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -587,7 +587,20 @@ impl<'db> Type<'db> { ), // TODO annotated return type on `__new__` or metaclass `__call__` - Type::Class(class) => CallOutcome::callable(Type::Instance(class)), + Type::Class(class) => { + // If the class is the builtin-bool class (for example `bool(1)`), we try to return + // the specific truthiness value of the input arg, `Literal[True]` for the example above. + let is_bool = class.is_stdlib_symbol(db, "builtins", "bool"); + CallOutcome::callable(if is_bool { + arg_types + .first() + .unwrap_or(&Type::Unknown) + .bool(db) + .into_type(db) + } else { + Type::Instance(class) + }) + } // TODO: handle classes which implement the `__call__` protocol Type::Instance(_instance_ty) => CallOutcome::callable(Type::Unknown), diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 32484fe2bf..b2a0c3d768 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2711,12 +2711,6 @@ mod tests { use anyhow::Context; - use ruff_db::files::{system_path_to_file, File}; - use ruff_db::parsed::parsed_module; - use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; - use ruff_db::testing::assert_function_query_was_not_run; - use ruff_python_ast::name::Name; - use crate::db::tests::TestDb; use crate::program::{Program, SearchPathSettings}; use crate::python_version::PythonVersion; @@ -2728,6 +2722,11 @@ mod tests { check_types, global_symbol_ty, infer_definition_types, symbol_ty, TypeCheckDiagnostics, }; use crate::{HasTy, ProgramSettings, SemanticModel}; + use ruff_db::files::{system_path_to_file, File}; + use ruff_db::parsed::parsed_module; + use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; + use ruff_db::testing::assert_function_query_was_not_run; + use ruff_python_ast::name::Name; use super::TypeInferenceBuilder; @@ -6483,4 +6482,91 @@ mod tests { assert_public_ty(&db, "/src/a.py", "f", r#"Literal["x"]"#); Ok(()) } + + #[test] + fn bool_function_falsy_values() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_dedented( + "/src/a.py", + r#" + a = bool(0) + b = bool(()) + c = bool(None) + d = bool("") + e = bool(False) + "#, + )?; + assert_public_ty(&db, "/src/a.py", "a", "Literal[False]"); + assert_public_ty(&db, "/src/a.py", "b", "Literal[False]"); + assert_public_ty(&db, "/src/a.py", "c", "Literal[False]"); + assert_public_ty(&db, "/src/a.py", "d", "Literal[False]"); + assert_public_ty(&db, "/src/a.py", "e", "Literal[False]"); + Ok(()) + } + + #[test] + fn builtin_bool_function_detected() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_dedented( + "/src/a.py", + " + redefined_builtin_bool = bool + + def my_bool(x)-> bool: pass + ", + )?; + db.write_dedented( + "/src/b.py", + " + from a import redefined_builtin_bool, my_bool + a = redefined_builtin_bool(0) + b = my_bool(0) + ", + )?; + assert_public_ty(&db, "/src/b.py", "a", "Literal[False]"); + assert_public_ty(&db, "/src/b.py", "b", "bool"); + Ok(()) + } + + #[test] + fn bool_function_truthy_values() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_dedented( + "/src/a.py", + r#" + a = bool(1) + b = bool((0,)) + c = bool("NON EMPTY") + d = bool(True) + + def foo(): pass + e = bool(foo) + "#, + )?; + + assert_public_ty(&db, "/src/a.py", "a", "Literal[True]"); + assert_public_ty(&db, "/src/a.py", "b", "Literal[True]"); + assert_public_ty(&db, "/src/a.py", "c", "Literal[True]"); + assert_public_ty(&db, "/src/a.py", "d", "Literal[True]"); + assert_public_ty(&db, "/src/a.py", "e", "Literal[True]"); + Ok(()) + } + + #[test] + fn bool_function_ambiguous_values() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_dedented( + "/src/a.py", + " + a = bool([]) + b = bool({}) + c = bool(set()) + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "bool"); + assert_public_ty(&db, "/src/a.py", "b", "bool"); + assert_public_ty(&db, "/src/a.py", "c", "bool"); + Ok(()) + } }