mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-03 18:28:24 +00:00
[red-knot] Type narrowing inside boolean expressions (#13970)
## Summary
This PR adds type narrowing in `and` and `or` expressions, for example:
```py
class A: ...
x: A | None = A() if bool_instance() else None
isinstance(x, A) or reveal_type(x) # revealed: None
```
## Test Plan
New mdtests 😍
This commit is contained in:
parent
ec6208e51b
commit
9a0dade925
2 changed files with 124 additions and 18 deletions
|
@ -0,0 +1,93 @@
|
|||
# Narrowing in boolean expressions
|
||||
|
||||
In `or` expressions, the right-hand side is evaluated only if the left-hand side is **falsy**.
|
||||
So when the right-hand side is evaluated, we know the left side has failed.
|
||||
|
||||
Similarly, in `and` expressions, the right-hand side is evaluated only if the left-hand side is **truthy**.
|
||||
So when the right-hand side is evaluated, we know the left side has succeeded.
|
||||
|
||||
## Narrowing in `or`
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
class A: ...
|
||||
|
||||
x: A | None = A() if bool_instance() else None
|
||||
|
||||
isinstance(x, A) or reveal_type(x) # revealed: None
|
||||
x is None or reveal_type(x) # revealed: A
|
||||
reveal_type(x) # revealed: A | None
|
||||
```
|
||||
|
||||
## Narrowing in `and`
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
class A: ...
|
||||
|
||||
x: A | None = A() if bool_instance() else None
|
||||
|
||||
isinstance(x, A) and reveal_type(x) # revealed: A
|
||||
x is None and reveal_type(x) # revealed: None
|
||||
reveal_type(x) # revealed: A | None
|
||||
```
|
||||
|
||||
## Multiple `and` arms
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
class A: ...
|
||||
|
||||
x: A | None = A() if bool_instance() else None
|
||||
|
||||
bool_instance() and isinstance(x, A) and reveal_type(x) # revealed: A
|
||||
isinstance(x, A) and bool_instance() and reveal_type(x) # revealed: A
|
||||
reveal_type(x) and isinstance(x, A) and bool_instance() # revealed: A | None
|
||||
```
|
||||
|
||||
## Multiple `or` arms
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
class A: ...
|
||||
|
||||
x: A | None = A() if bool_instance() else None
|
||||
|
||||
bool_instance() or isinstance(x, A) or reveal_type(x) # revealed: None
|
||||
isinstance(x, A) or bool_instance() or reveal_type(x) # revealed: None
|
||||
reveal_type(x) or isinstance(x, A) or bool_instance() # revealed: A | None
|
||||
```
|
||||
|
||||
## Multiple predicates
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
class A: ...
|
||||
|
||||
x: A | None | Literal[1] = A() if bool_instance() else None if bool_instance() else 1
|
||||
|
||||
x is None or isinstance(x, A) or reveal_type(x) # revealed: Literal[1]
|
||||
```
|
||||
|
||||
## Mix of `and` and `or`
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
class A: ...
|
||||
|
||||
x: A | None | Literal[1] = A() if bool_instance() else None if bool_instance() else 1
|
||||
|
||||
isinstance(x, A) or x is not None and reveal_type(x) # revealed: Literal[1]
|
||||
```
|
|
@ -9,7 +9,7 @@ use ruff_index::IndexVec;
|
|||
use ruff_python_ast as ast;
|
||||
use ruff_python_ast::name::Name;
|
||||
use ruff_python_ast::visitor::{walk_expr, walk_pattern, walk_stmt, Visitor};
|
||||
use ruff_python_ast::AnyParameterRef;
|
||||
use ruff_python_ast::{AnyParameterRef, BoolOp, Expr};
|
||||
|
||||
use crate::ast_node_ref::AstNodeRef;
|
||||
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
|
||||
|
@ -243,18 +243,25 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
definition
|
||||
}
|
||||
|
||||
fn add_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Constraint<'db> {
|
||||
let expression = self.add_standalone_expression(constraint_node);
|
||||
let constraint = Constraint {
|
||||
node: ConstraintNode::Expression(expression),
|
||||
is_positive: true,
|
||||
};
|
||||
self.current_use_def_map_mut().record_constraint(constraint);
|
||||
|
||||
fn record_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Constraint<'db> {
|
||||
let constraint = self.build_constraint(constraint_node);
|
||||
self.record_constraint(constraint);
|
||||
constraint
|
||||
}
|
||||
|
||||
fn add_negated_constraint(&mut self, constraint: Constraint<'db>) {
|
||||
fn record_constraint(&mut self, constraint: Constraint<'db>) {
|
||||
self.current_use_def_map_mut().record_constraint(constraint);
|
||||
}
|
||||
|
||||
fn build_constraint(&mut self, constraint_node: &Expr) -> Constraint<'db> {
|
||||
let expression = self.add_standalone_expression(constraint_node);
|
||||
Constraint {
|
||||
node: ConstraintNode::Expression(expression),
|
||||
is_positive: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn record_negated_constraint(&mut self, constraint: Constraint<'db>) {
|
||||
self.current_use_def_map_mut()
|
||||
.record_constraint(Constraint {
|
||||
node: constraint.node,
|
||||
|
@ -653,7 +660,7 @@ where
|
|||
ast::Stmt::If(node) => {
|
||||
self.visit_expr(&node.test);
|
||||
let pre_if = self.flow_snapshot();
|
||||
let constraint = self.add_expression_constraint(&node.test);
|
||||
let constraint = self.record_expression_constraint(&node.test);
|
||||
let mut constraints = vec![constraint];
|
||||
self.visit_body(&node.body);
|
||||
let mut post_clauses: Vec<FlowSnapshot> = vec![];
|
||||
|
@ -665,11 +672,11 @@ where
|
|||
// taken, so the block entry state is always `pre_if`
|
||||
self.flow_restore(pre_if.clone());
|
||||
for constraint in &constraints {
|
||||
self.add_negated_constraint(*constraint);
|
||||
self.record_negated_constraint(*constraint);
|
||||
}
|
||||
if let Some(elif_test) = &clause.test {
|
||||
self.visit_expr(elif_test);
|
||||
constraints.push(self.add_expression_constraint(elif_test));
|
||||
constraints.push(self.record_expression_constraint(elif_test));
|
||||
}
|
||||
self.visit_body(&clause.body);
|
||||
}
|
||||
|
@ -1109,7 +1116,7 @@ where
|
|||
ast::Expr::BoolOp(ast::ExprBoolOp {
|
||||
values,
|
||||
range: _,
|
||||
op: _,
|
||||
op,
|
||||
}) => {
|
||||
// TODO detect statically known truthy or falsy values (via type inference, not naive
|
||||
// AST inspection, so we can't simplify here, need to record test expression for
|
||||
|
@ -1117,11 +1124,17 @@ where
|
|||
let mut snapshots = vec![];
|
||||
|
||||
for (index, value) in values.iter().enumerate() {
|
||||
// The first item of BoolOp is always evaluated
|
||||
if index > 0 {
|
||||
snapshots.push(self.flow_snapshot());
|
||||
}
|
||||
self.visit_expr(value);
|
||||
// In the last value we don't need to take a snapshot nor add a constraint
|
||||
if index < values.len() - 1 {
|
||||
// Snapshot is taken after visiting the expression but before adding the constraint.
|
||||
snapshots.push(self.flow_snapshot());
|
||||
let constraint = self.build_constraint(value);
|
||||
match op {
|
||||
BoolOp::And => self.record_constraint(constraint),
|
||||
BoolOp::Or => self.record_negated_constraint(constraint),
|
||||
}
|
||||
}
|
||||
}
|
||||
for snapshot in snapshots {
|
||||
self.flow_merge(snapshot);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue