[ty] Better control flow for boolean expressions that are inside if (#18010)

## Summary
With this PR we now detect that x is always defined in `use`:
```py
if flag and (x := number):
    use(x)
```

When outside if, it's still detected as possibly not defined
```py
flag and (x := number)
# error: [possibly-unresolved-reference]
use(x)
```
In order to achieve that, I had to find a way to get access to the
flow-snapshots of the boolean expression when analyzing the flow of the
if statement. I did it by special casing the visitor of boolean
expression to return flow control information, exporting two snapshots -
`maybe_short_circuit` and `no_short_circuit`. When indexing
boolean expression itself we must assume all possible flows, but when
it's inside if statement, we can be smarter than that.

## Test Plan
Fixed existing and added new mdtests.
I went through some of mypy primer results and they look fine

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
TomerBin 2025-05-16 14:59:21 +03:00 committed by GitHub
parent 9ae698fe30
commit 9910ec700c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 317 additions and 75 deletions

View file

@ -7,14 +7,14 @@ Similarly, in `and` expressions, if the left-hand side is falsy, the right-hand
evaluated.
```py
def _(flag: bool):
if flag or (x := 1):
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: Literal[1]
def _(flag: bool, number: int):
flag or (y := number)
# error: [possibly-unresolved-reference]
reveal_type(y) # revealed: int
if flag and (x := 1):
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: Literal[1]
flag and (x := number)
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
```
## First expression is always evaluated
@ -65,3 +65,156 @@ def _(flag1: bool, flag2: bool):
# error: [possibly-unresolved-reference]
reveal_type(z) # revealed: Literal[1]
```
## Inside if-else blocks, we can sometimes know that short-circuit couldn't happen
When if-test contains `And` condition, in the scope of if-body we can be sure that the test is
truthy and therefore short-circuiting couldn't happen. Similarly, when if-test contains `Or`
condition, in the scope of if-else we can be sure that the test is falsy, and therefore
short-circuiting couldn't happen.
### And
```py
def _(flag: bool, number: int):
if flag and (x := number):
# x must be defined here
reveal_type(x) # revealed: int & ~AlwaysFalsy
else:
# TODO: could be int & AlwaysFalsy
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
```
### Or
```py
def _(flag: bool, number: int):
if flag or (x := number):
# TODO: could be int & AlwaysTruthy
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
else:
# x must be defined here
reveal_type(x) # revealed: int & ~AlwaysTruthy
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
```
### Elif
```py
def _(flag: bool, flag2: bool, number: int):
if flag or (x := number):
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
elif flag2 or (y := number):
# x must be defined here
reveal_type(x) # revealed: int & ~AlwaysTruthy
# error: [possibly-unresolved-reference]
reveal_type(y) # revealed: int
else:
# x and y must be defined here
reveal_type(x) # revealed: int & ~AlwaysTruthy
reveal_type(y) # revealed: int & ~AlwaysTruthy
if flag or (x := number):
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
elif flag2 and (y := number):
# x must be defined here
reveal_type(x) # revealed: int & ~AlwaysTruthy
reveal_type(y) # revealed: int & ~AlwaysFalsy
else:
# x must be defined here
reveal_type(x) # revealed: int & ~AlwaysTruthy
# error: [possibly-unresolved-reference]
reveal_type(y) # revealed: int
if flag and (x := number):
reveal_type(x) # revealed: int & ~AlwaysFalsy
elif flag2 or (y := number):
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
# error: [possibly-unresolved-reference]
reveal_type(y) # revealed: int
else:
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
reveal_type(y) # revealed: int & ~AlwaysTruthy
```
### Nested boolean expression
```py
def _(flag: bool, number: int):
# error: [possibly-unresolved-reference]
(flag or (x := number)) and reveal_type(x) # revealed: int
def _(flag: bool, number: int):
# x must be defined here
(flag or (x := number)) or reveal_type(x) # revealed: int & ~AlwaysTruthy
def _(flag: bool, flag_2: bool, number: int):
if flag and (flag_2 and (x := number)):
# x must be defined here
reveal_type(x) # revealed: int & ~AlwaysFalsy
def _(flag: bool, flag_2: bool, number: int):
if flag and (flag_2 or (x := number)):
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
else:
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
def _(flag: bool, flag_2: bool, number: int):
if flag or (flag_2 or (x := number)):
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
else:
# x must be defined here
reveal_type(x) # revealed: int & ~AlwaysTruthy
```
## This logic can be applied in additional cases that aren't supported yet
### If Expression
```py
def _(flag: bool, number: int):
# TODO: x must be defined here
# error: [possibly-unresolved-reference]
reveal_type(x) if flag and (x := number) else None # revealed: int & ~AlwaysFalsy
```
### While Statement
```py
def _(flag: bool, number: int):
while flag and (x := number):
# TODO: x must be defined here
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int & ~AlwaysFalsy
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
def _(flag: bool, number: int):
while flag or (x := number):
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int
# TODO: x must be defined here
# error: [possibly-unresolved-reference]
reveal_type(x) # revealed: int & ~AlwaysTruthy
```

View file

@ -52,7 +52,7 @@ def _(x: A | B):
if False and isinstance(x, A):
# TODO: should emit an `unreachable code` diagnostic
reveal_type(x) # revealed: A
reveal_type(x) # revealed: Never
else:
reveal_type(x) # revealed: A | B
@ -65,7 +65,7 @@ def _(x: A | B):
reveal_type(x) # revealed: A | B
else:
# TODO: should emit an `unreachable code` diagnostic
reveal_type(x) # revealed: B & ~A
reveal_type(x) # revealed: Never
reveal_type(x) # revealed: A | B
```

View file

@ -10,7 +10,7 @@ use ruff_db::source::{SourceText, source_text};
use ruff_index::IndexVec;
use ruff_python_ast::name::Name;
use ruff_python_ast::visitor::{Visitor, walk_expr, walk_pattern, walk_stmt};
use ruff_python_ast::{self as ast, PySourceType, PythonVersion};
use ruff_python_ast::{self as ast, BoolOp, Expr, PySourceType, PythonVersion};
use ruff_python_parser::semantic_errors::{
SemanticSyntaxChecker, SemanticSyntaxContext, SemanticSyntaxError, SemanticSyntaxErrorKind,
};
@ -71,6 +71,55 @@ struct ScopeInfo {
current_loop: Option<Loop>,
}
enum TestFlowSnapshots {
BooleanExprTest {
maybe_short_circuit: FlowSnapshot,
no_short_circuit: FlowSnapshot,
op: BoolOp,
},
Default(FlowSnapshot),
}
impl TestFlowSnapshots {
fn flow(&self) -> &FlowSnapshot {
match self {
TestFlowSnapshots::Default(snapshot) => snapshot,
TestFlowSnapshots::BooleanExprTest {
maybe_short_circuit,
..
} => maybe_short_circuit,
}
}
fn falsy_flow(&self) -> &FlowSnapshot {
match self {
TestFlowSnapshots::Default(flow_control) => flow_control,
TestFlowSnapshots::BooleanExprTest {
maybe_short_circuit,
no_short_circuit,
op,
} => match op {
BoolOp::And => maybe_short_circuit,
BoolOp::Or => no_short_circuit,
},
}
}
fn truthy_flow(&self) -> &FlowSnapshot {
match self {
TestFlowSnapshots::Default(flow_control) => flow_control,
TestFlowSnapshots::BooleanExprTest {
maybe_short_circuit,
no_short_circuit,
op,
} => match op {
BoolOp::And => no_short_circuit,
BoolOp::Or => maybe_short_circuit,
},
}
}
}
pub(super) struct SemanticIndexBuilder<'db> {
// Builder state
db: &'db dyn Db,
@ -1512,8 +1561,8 @@ where
}
}
ast::Stmt::If(node) => {
self.visit_expr(&node.test);
let mut no_branch_taken = self.flow_snapshot();
let mut after_test = self.visit_test_expr(&node.test);
self.flow_restore(after_test.truthy_flow().clone());
let mut last_predicate = self.record_expression_narrowing_constraint(&node.test);
let mut reachability_constraint =
self.record_reachability_constraint(last_predicate);
@ -1544,14 +1593,13 @@ where
post_clauses.push(self.flow_snapshot());
// we can only take an elif/else branch if none of the previous ones were
// taken
self.flow_restore(no_branch_taken.clone());
self.flow_restore(after_test.falsy_flow().clone());
self.record_negated_narrowing_constraint(last_predicate);
self.record_negated_reachability_constraint(reachability_constraint);
let elif_predicate = if let Some(elif_test) = clause_test {
self.visit_expr(elif_test);
// A test expression is evaluated whether the branch is taken or not
no_branch_taken = self.flow_snapshot();
after_test = self.visit_test_expr(elif_test);
self.flow_restore(after_test.truthy_flow().clone());
reachability_constraint =
self.record_reachability_constraint(last_predicate);
let predicate = self.record_expression_narrowing_constraint(elif_test);
@ -1576,7 +1624,7 @@ where
self.flow_merge(post_clause_state);
}
self.simplify_visibility_constraints(no_branch_taken);
self.simplify_visibility_constraints(after_test.flow().clone());
}
ast::Stmt::While(ast::StmtWhile {
test,
@ -1960,13 +2008,7 @@ where
}
fn visit_expr(&mut self, expr: &'ast ast::Expr) {
self.with_semantic_checker(|semantic, context| semantic.visit_expr(expr, context));
self.scopes_by_expression
.insert(expr.into(), self.current_scope());
self.current_ast_ids().record_expression(expr);
let node_key = NodeKey::from_node(expr);
let node_key = self.prepare_expr(expr);
match expr {
ast::Expr::Name(ast::ExprName { id, ctx, .. }) => {
@ -2185,57 +2227,7 @@ where
range: _,
op,
}) => {
let pre_op = self.flow_snapshot();
let mut snapshots = vec![];
let mut visibility_constraints = vec![];
for (index, value) in values.iter().enumerate() {
self.visit_expr(value);
for vid in &visibility_constraints {
self.record_visibility_constraint_id(*vid);
}
// For the last value, we don't need to model control flow. There is no short-circuiting
// anymore.
if index < values.len() - 1 {
let predicate = self.build_predicate(value);
let predicate_id = match op {
ast::BoolOp::And => self.add_predicate(predicate),
ast::BoolOp::Or => self.add_negated_predicate(predicate),
};
let visibility_constraint = self
.current_visibility_constraints_mut()
.add_atom(predicate_id);
let after_expr = self.flow_snapshot();
// We first model the short-circuiting behavior. We take the short-circuit
// path here if all of the previous short-circuit paths were not taken, so
// we record all previously existing visibility constraints, and negate the
// one for the current expression.
for vid in &visibility_constraints {
self.record_visibility_constraint_id(*vid);
}
self.record_negated_visibility_constraint(visibility_constraint);
snapshots.push(self.flow_snapshot());
// Then we model the non-short-circuiting behavior. Here, we need to delay
// the application of the visibility constraint until after the expression
// has been evaluated, so we only push it onto the stack here.
self.flow_restore(after_expr);
self.record_narrowing_constraint_id(predicate_id);
self.record_reachability_constraint_id(predicate_id);
visibility_constraints.push(visibility_constraint);
}
}
for snapshot in snapshots {
self.flow_merge(snapshot);
}
self.simplify_visibility_constraints(pre_op);
self.visit_bool_op_expr(values, *op);
}
ast::Expr::Attribute(ast::ExprAttribute {
value: object,
@ -2428,6 +2420,103 @@ where
}
}
impl<'db> SemanticIndexBuilder<'db> {
fn visit_test_expr(&mut self, test_expr: &'db Expr) -> TestFlowSnapshots {
match test_expr {
ast::Expr::BoolOp(ast::ExprBoolOp {
values,
range: _,
op,
}) => {
self.prepare_expr(test_expr);
self.visit_bool_op_expr(values, *op)
}
_ => {
self.visit_expr(test_expr);
TestFlowSnapshots::Default(self.flow_snapshot())
}
}
}
}
impl<'db> SemanticIndexBuilder<'db> {
fn visit_bool_op_expr(&mut self, values: &'db [Expr], op: BoolOp) -> TestFlowSnapshots {
let pre_op = self.flow_snapshot();
let mut short_circuits = vec![];
let mut visibility_constraints = vec![];
for (index, value) in values.iter().enumerate() {
let after_test = self.visit_test_expr(value);
self.flow_restore(match op {
ast::BoolOp::And => after_test.truthy_flow().clone(),
ast::BoolOp::Or => after_test.falsy_flow().clone(),
});
for vid in &visibility_constraints {
self.record_visibility_constraint_id(*vid);
}
// For the last value, we don't need to model control flow. There is no short-circuiting
// anymore.
if index < values.len() - 1 {
let predicate = self.build_predicate(value);
let predicate_id = match op {
ast::BoolOp::And => self.add_predicate(predicate),
ast::BoolOp::Or => self.add_negated_predicate(predicate),
};
let visibility_constraint = self
.current_visibility_constraints_mut()
.add_atom(predicate_id);
let after_expr = self.flow_snapshot();
// We first model the short-circuiting behavior. We take the short-circuit
// path here if all of the previous short-circuit paths were not taken, so
// we record all previously existing visibility constraints, and negate the
// one for the current expression.
for vid in &visibility_constraints {
self.record_visibility_constraint_id(*vid);
}
self.record_negated_visibility_constraint(visibility_constraint);
short_circuits.push(self.flow_snapshot());
// Then we model the non-short-circuiting behavior. Here, we need to delay
// the application of the visibility constraint until after the expression
// has been evaluated, so we only push it onto the stack here.
self.flow_restore(after_expr);
self.record_narrowing_constraint_id(predicate_id);
self.record_reachability_constraint_id(predicate_id);
visibility_constraints.push(visibility_constraint);
}
}
let no_short_circuit = self.flow_snapshot();
for snapshot in short_circuits.clone() {
self.flow_merge(snapshot);
}
let maybe_short_circuit = self.flow_snapshot();
self.simplify_visibility_constraints(pre_op);
TestFlowSnapshots::BooleanExprTest {
maybe_short_circuit,
no_short_circuit,
op,
}
}
}
impl SemanticIndexBuilder<'_> {
fn prepare_expr(&mut self, expr: &Expr) -> NodeKey {
self.with_semantic_checker(|semantic, context| semantic.visit_expr(expr, context));
self.scopes_by_expression
.insert(expr.into(), self.current_scope());
self.current_ast_ids().record_expression(expr);
NodeKey::from_node(expr)
}
}
impl SemanticSyntaxContext for SemanticIndexBuilder<'_> {
fn future_annotations_or_stub(&self) -> bool {
self.has_future_annotations