[red-knot] Type narrowing for isinstance checks (#13894)

## Summary

Add type narrowing for `isinstance(object, classinfo)` [1] checks:
```py
x = 1 if flag else "a"

if isinstance(x, int):
    reveal_type(x)  # revealed: Literal[1]
```

closes #13893

[1] https://docs.python.org/3/library/functions.html#isinstance

## Test Plan

New Markdown-based tests in `narrow/isinstance.md`.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
David Peter 2024-10-23 20:51:33 +02:00 committed by GitHub
parent 72c18c8225
commit 2c57c2dc8a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 241 additions and 12 deletions

View file

@ -47,7 +47,13 @@ impl<'db> Definition<'db> {
self.kind(db).category().is_binding()
}
/// Return true if this is a symbol was defined in the `typing` or `typing_extensions` modules
pub(crate) fn is_builtin_definition(self, db: &'db dyn Db) -> bool {
file_to_module(db, self.file(db)).is_some_and(|module| {
module.search_path().is_standard_library() && matches!(&**module.name(), "builtins")
})
}
/// Return true if this symbol was defined in the `typing` or `typing_extensions` modules
pub(crate) fn is_typing_definition(self, db: &'db dyn Db) -> bool {
file_to_module(db, self.file(db)).is_some_and(|module| {
module.search_path().is_standard_library()

View file

@ -868,13 +868,16 @@ impl<'db> Type<'db> {
fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> {
match self {
// TODO validate typed call arguments vs callable signature
Type::FunctionLiteral(function_type) => match function_type.known(db) {
None => CallOutcome::callable(function_type.return_type(db)),
Some(KnownFunction::RevealType) => CallOutcome::revealed(
function_type.return_type(db),
*arg_types.first().unwrap_or(&Type::Unknown),
),
},
Type::FunctionLiteral(function_type) => {
if function_type.is_known(db, KnownFunction::RevealType) {
CallOutcome::revealed(
function_type.return_type(db),
*arg_types.first().unwrap_or(&Type::Unknown),
)
} else {
CallOutcome::callable(function_type.return_type(db))
}
}
// TODO annotated return type on `__new__` or metaclass `__call__`
Type::ClassLiteral(class) => {
@ -1595,6 +1598,10 @@ impl<'db> FunctionType<'db> {
})
.unwrap_or(Type::Unknown)
}
pub fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
self.known(db) == Some(known_function)
}
}
/// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might
@ -1603,6 +1610,8 @@ impl<'db> FunctionType<'db> {
pub enum KnownFunction {
/// `builtins.reveal_type`, `typing.reveal_type` or `typing_extensions.reveal_type`
RevealType,
/// `builtins.isinstance`
IsInstance,
}
#[salsa::interned]

View file

@ -779,6 +779,9 @@ impl<'db> TypeInferenceBuilder<'db> {
"reveal_type" if definition.is_typing_definition(self.db) => {
Some(KnownFunction::RevealType)
}
"isinstance" if definition.is_builtin_definition(self.db) => {
Some(KnownFunction::IsInstance)
}
_ => None,
};
let function_ty = Type::FunctionLiteral(FunctionType::new(

View file

@ -4,7 +4,9 @@ use crate::semantic_index::definition::Definition;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
use crate::semantic_index::symbol_table;
use crate::types::{infer_expression_types, IntersectionBuilder, Type};
use crate::types::{
infer_expression_types, IntersectionBuilder, KnownFunction, Type, UnionBuilder,
};
use crate::Db;
use itertools::Itertools;
use ruff_python_ast as ast;
@ -60,6 +62,28 @@ fn all_narrowing_constraints_for_expression<'db>(
NarrowingConstraintsBuilder::new(db, Constraint::Expression(expression)).finish()
}
/// Generate a constraint from the *type* of the second argument of an `isinstance` call.
///
/// Example: for `isinstance(…, str)`, we would infer `Type::ClassLiteral(str)` from the
/// second argument, but we need to generate a `Type::Instance(str)` constraint that can
/// be used to narrow down the type of the first argument.
fn generate_isinstance_constraint<'db>(
db: &'db dyn Db,
classinfo: &Type<'db>,
) -> Option<Type<'db>> {
match classinfo {
Type::ClassLiteral(class) => Some(Type::Instance(*class)),
Type::Tuple(tuple) => {
let mut builder = UnionBuilder::new(db);
for element in tuple.elements(db) {
builder = builder.add(generate_isinstance_constraint(db, element)?);
}
Some(builder.build())
}
_ => None,
}
}
type NarrowingConstraints<'db> = FxHashMap<ScopedSymbolId, Type<'db>>;
struct NarrowingConstraintsBuilder<'db> {
@ -88,10 +112,15 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}
fn evaluate_expression_constraint(&mut self, expression: Expression<'db>) {
if let ast::Expr::Compare(expr_compare) = expression.node_ref(self.db).node() {
self.add_expr_compare(expr_compare, expression);
match expression.node_ref(self.db).node() {
ast::Expr::Compare(expr_compare) => {
self.add_expr_compare(expr_compare, expression);
}
ast::Expr::Call(expr_call) => {
self.add_expr_call(expr_call, expression);
}
_ => {} // TODO other test expression kinds
}
// TODO other test expression kinds
}
fn evaluate_pattern_constraint(&mut self, pattern: PatternConstraint<'db>) {
@ -194,6 +223,33 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}
}
fn add_expr_call(&mut self, expr_call: &ast::ExprCall, expression: Expression<'db>) {
let scope = self.scope();
let inference = infer_expression_types(self.db, expression);
if let Some(func_type) = inference
.expression_ty(expr_call.func.scoped_ast_id(self.db, scope))
.into_function_literal_type()
{
if func_type.is_known(self.db, KnownFunction::IsInstance)
&& expr_call.arguments.keywords.is_empty()
{
if let [ast::Expr::Name(ast::ExprName { id, .. }), rhs] = &*expr_call.arguments.args
{
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let rhs_type = inference.expression_ty(rhs.scoped_ast_id(self.db, scope));
// TODO: add support for PEP 604 union types on the right hand side:
// isinstance(x, str | (int | float))
if let Some(constraint) = generate_isinstance_constraint(self.db, &rhs_type) {
self.constraints.insert(symbol, constraint);
}
}
}
}
}
fn add_match_pattern_singleton(
&mut self,
subject: &ast::Expr,