[red-knot] Type narrowing inside boolean expressions (#13970)

## Summary

This PR adds type narrowing in `and` and `or` expressions, for example:

```py
class A: ...

x: A | None = A() if bool_instance() else None

isinstance(x, A) or reveal_type(x)  # revealed: None
``` 

## Test Plan
New mdtests 😍
This commit is contained in:
TomerBin 2024-10-29 03:17:48 +02:00 committed by GitHub
parent ec6208e51b
commit 9a0dade925
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 124 additions and 18 deletions

View file

@ -0,0 +1,93 @@
# Narrowing in boolean expressions
In `or` expressions, the right-hand side is evaluated only if the left-hand side is **falsy**.
So when the right-hand side is evaluated, we know the left side has failed.
Similarly, in `and` expressions, the right-hand side is evaluated only if the left-hand side is **truthy**.
So when the right-hand side is evaluated, we know the left side has succeeded.
## Narrowing in `or`
```py
def bool_instance() -> bool:
return True
class A: ...
x: A | None = A() if bool_instance() else None
isinstance(x, A) or reveal_type(x) # revealed: None
x is None or reveal_type(x) # revealed: A
reveal_type(x) # revealed: A | None
```
## Narrowing in `and`
```py
def bool_instance() -> bool:
return True
class A: ...
x: A | None = A() if bool_instance() else None
isinstance(x, A) and reveal_type(x) # revealed: A
x is None and reveal_type(x) # revealed: None
reveal_type(x) # revealed: A | None
```
## Multiple `and` arms
```py
def bool_instance() -> bool:
return True
class A: ...
x: A | None = A() if bool_instance() else None
bool_instance() and isinstance(x, A) and reveal_type(x) # revealed: A
isinstance(x, A) and bool_instance() and reveal_type(x) # revealed: A
reveal_type(x) and isinstance(x, A) and bool_instance() # revealed: A | None
```
## Multiple `or` arms
```py
def bool_instance() -> bool:
return True
class A: ...
x: A | None = A() if bool_instance() else None
bool_instance() or isinstance(x, A) or reveal_type(x) # revealed: None
isinstance(x, A) or bool_instance() or reveal_type(x) # revealed: None
reveal_type(x) or isinstance(x, A) or bool_instance() # revealed: A | None
```
## Multiple predicates
```py
def bool_instance() -> bool:
return True
class A: ...
x: A | None | Literal[1] = A() if bool_instance() else None if bool_instance() else 1
x is None or isinstance(x, A) or reveal_type(x) # revealed: Literal[1]
```
## Mix of `and` and `or`
```py
def bool_instance() -> bool:
return True
class A: ...
x: A | None | Literal[1] = A() if bool_instance() else None if bool_instance() else 1
isinstance(x, A) or x is not None and reveal_type(x) # revealed: Literal[1]
```

View file

@ -9,7 +9,7 @@ use ruff_index::IndexVec;
use ruff_python_ast as ast;
use ruff_python_ast::name::Name;
use ruff_python_ast::visitor::{walk_expr, walk_pattern, walk_stmt, Visitor};
use ruff_python_ast::AnyParameterRef;
use ruff_python_ast::{AnyParameterRef, BoolOp, Expr};
use crate::ast_node_ref::AstNodeRef;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
@ -243,18 +243,25 @@ impl<'db> SemanticIndexBuilder<'db> {
definition
}
fn add_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Constraint<'db> {
let expression = self.add_standalone_expression(constraint_node);
let constraint = Constraint {
node: ConstraintNode::Expression(expression),
is_positive: true,
};
self.current_use_def_map_mut().record_constraint(constraint);
fn record_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Constraint<'db> {
let constraint = self.build_constraint(constraint_node);
self.record_constraint(constraint);
constraint
}
fn add_negated_constraint(&mut self, constraint: Constraint<'db>) {
fn record_constraint(&mut self, constraint: Constraint<'db>) {
self.current_use_def_map_mut().record_constraint(constraint);
}
fn build_constraint(&mut self, constraint_node: &Expr) -> Constraint<'db> {
let expression = self.add_standalone_expression(constraint_node);
Constraint {
node: ConstraintNode::Expression(expression),
is_positive: true,
}
}
fn record_negated_constraint(&mut self, constraint: Constraint<'db>) {
self.current_use_def_map_mut()
.record_constraint(Constraint {
node: constraint.node,
@ -653,7 +660,7 @@ where
ast::Stmt::If(node) => {
self.visit_expr(&node.test);
let pre_if = self.flow_snapshot();
let constraint = self.add_expression_constraint(&node.test);
let constraint = self.record_expression_constraint(&node.test);
let mut constraints = vec![constraint];
self.visit_body(&node.body);
let mut post_clauses: Vec<FlowSnapshot> = vec![];
@ -665,11 +672,11 @@ where
// taken, so the block entry state is always `pre_if`
self.flow_restore(pre_if.clone());
for constraint in &constraints {
self.add_negated_constraint(*constraint);
self.record_negated_constraint(*constraint);
}
if let Some(elif_test) = &clause.test {
self.visit_expr(elif_test);
constraints.push(self.add_expression_constraint(elif_test));
constraints.push(self.record_expression_constraint(elif_test));
}
self.visit_body(&clause.body);
}
@ -1109,7 +1116,7 @@ where
ast::Expr::BoolOp(ast::ExprBoolOp {
values,
range: _,
op: _,
op,
}) => {
// TODO detect statically known truthy or falsy values (via type inference, not naive
// AST inspection, so we can't simplify here, need to record test expression for
@ -1117,11 +1124,17 @@ where
let mut snapshots = vec![];
for (index, value) in values.iter().enumerate() {
// The first item of BoolOp is always evaluated
if index > 0 {
snapshots.push(self.flow_snapshot());
}
self.visit_expr(value);
// In the last value we don't need to take a snapshot nor add a constraint
if index < values.len() - 1 {
// Snapshot is taken after visiting the expression but before adding the constraint.
snapshots.push(self.flow_snapshot());
let constraint = self.build_constraint(value);
match op {
BoolOp::And => self.record_constraint(constraint),
BoolOp::Or => self.record_negated_constraint(constraint),
}
}
}
for snapshot in snapshots {
self.flow_merge(snapshot);