[ty] Eagerly simplify 'True' and 'False' constraints (#18998)

## Summary

Simplifies literal `True` and `False` conditions to `ALWAYS_TRUE` /
`ALWAYS_FALSE` during semantic index building. This allows us to eagerly
evaluate more constraints, which should help with performance (looks
like there is a tiny 1% improvement in instrumented benchmarks), but
also allows us to eliminate definitely-unreachable branches in
control-flow merging. This can lead to better type inference in some
cases because it allows us to retain narrowing constraints without
solving https://github.com/astral-sh/ty/issues/690 first:
```py
def _(c: int | None):
    if c is None:
        assert False
    
    reveal_type(c)  # int, previously: int | None
```

closes https://github.com/astral-sh/ty/issues/713

## Test Plan

* Regression test for https://github.com/astral-sh/ty/issues/713
* Made sure that all ecosystem diffs trace back to removed false
positives
This commit is contained in:
David Peter 2025-06-30 13:11:52 +02:00 committed by GitHub
parent 54769ac9f9
commit db3dcd8ad6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 129 additions and 43 deletions

View file

@ -59,3 +59,17 @@ while x != 1:
x = next_item() x = next_item()
``` ```
## With `break` statements
```py
def next_item() -> int | None:
return 1
while True:
x = next_item()
if x is not None:
break
reveal_type(x) # revealed: int
```

View file

@ -35,8 +35,8 @@ use crate::semantic_index::place::{
PlaceExprWithFlags, PlaceTableBuilder, Scope, ScopeId, ScopeKind, ScopedPlaceId, PlaceExprWithFlags, PlaceTableBuilder, Scope, ScopeId, ScopeKind, ScopedPlaceId,
}; };
use crate::semantic_index::predicate::{ use crate::semantic_index::predicate::{
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, ScopedPredicateId, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, PredicateOrLiteral,
StarImportPlaceholderPredicate, ScopedPredicateId, StarImportPlaceholderPredicate,
}; };
use crate::semantic_index::re_exports::exported_names; use crate::semantic_index::re_exports::exported_names;
use crate::semantic_index::reachability_constraints::{ use crate::semantic_index::reachability_constraints::{
@ -535,29 +535,34 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
fn record_expression_narrowing_constraint( fn record_expression_narrowing_constraint(
&mut self, &mut self,
precide_node: &ast::Expr, precide_node: &ast::Expr,
) -> Predicate<'db> { ) -> PredicateOrLiteral<'db> {
let predicate = self.build_predicate(precide_node); let predicate = self.build_predicate(precide_node);
self.record_narrowing_constraint(predicate); self.record_narrowing_constraint(predicate);
predicate predicate
} }
fn build_predicate(&mut self, predicate_node: &ast::Expr) -> Predicate<'db> { fn build_predicate(&mut self, predicate_node: &ast::Expr) -> PredicateOrLiteral<'db> {
let expression = self.add_standalone_expression(predicate_node); let expression = self.add_standalone_expression(predicate_node);
Predicate {
if let Some(boolean_literal) = predicate_node.as_boolean_literal_expr() {
PredicateOrLiteral::Literal(boolean_literal.value)
} else {
PredicateOrLiteral::Predicate(Predicate {
node: PredicateNode::Expression(expression), node: PredicateNode::Expression(expression),
is_positive: true, is_positive: true,
})
} }
} }
/// Adds a new predicate to the list of all predicates, but does not record it. Returns the /// Adds a new predicate to the list of all predicates, but does not record it. Returns the
/// predicate ID for later recording using /// predicate ID for later recording using
/// [`SemanticIndexBuilder::record_narrowing_constraint_id`]. /// [`SemanticIndexBuilder::record_narrowing_constraint_id`].
fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId { fn add_predicate(&mut self, predicate: PredicateOrLiteral<'db>) -> ScopedPredicateId {
self.current_use_def_map_mut().add_predicate(predicate) self.current_use_def_map_mut().add_predicate(predicate)
} }
/// Negates a predicate and adds it to the list of all predicates, does not record it. /// Negates a predicate and adds it to the list of all predicates, does not record it.
fn add_negated_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId { fn add_negated_predicate(&mut self, predicate: PredicateOrLiteral<'db>) -> ScopedPredicateId {
self.current_use_def_map_mut() self.current_use_def_map_mut()
.add_predicate(predicate.negated()) .add_predicate(predicate.negated())
} }
@ -569,7 +574,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
} }
/// Adds and records a narrowing constraint, i.e. adds it to all live bindings. /// Adds and records a narrowing constraint, i.e. adds it to all live bindings.
fn record_narrowing_constraint(&mut self, predicate: Predicate<'db>) { fn record_narrowing_constraint(&mut self, predicate: PredicateOrLiteral<'db>) {
let use_def = self.current_use_def_map_mut(); let use_def = self.current_use_def_map_mut();
let predicate_id = use_def.add_predicate(predicate); let predicate_id = use_def.add_predicate(predicate);
use_def.record_narrowing_constraint(predicate_id); use_def.record_narrowing_constraint(predicate_id);
@ -579,7 +584,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
/// bindings. /// bindings.
fn record_negated_narrowing_constraint( fn record_negated_narrowing_constraint(
&mut self, &mut self,
predicate: Predicate<'db>, predicate: PredicateOrLiteral<'db>,
) -> ScopedPredicateId { ) -> ScopedPredicateId {
let id = self.add_negated_predicate(predicate); let id = self.add_negated_predicate(predicate);
self.record_narrowing_constraint_id(id); self.record_narrowing_constraint_id(id);
@ -603,7 +608,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
/// we know that all statements that follow in this path of control flow will be unreachable. /// we know that all statements that follow in this path of control flow will be unreachable.
fn record_reachability_constraint( fn record_reachability_constraint(
&mut self, &mut self,
predicate: Predicate<'db>, predicate: PredicateOrLiteral<'db>,
) -> ScopedReachabilityConstraintId { ) -> ScopedReachabilityConstraintId {
let predicate_id = self.add_predicate(predicate); let predicate_id = self.add_predicate(predicate);
self.record_reachability_constraint_id(predicate_id) self.record_reachability_constraint_id(predicate_id)
@ -617,6 +622,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
let reachability_constraint = self let reachability_constraint = self
.current_reachability_constraints_mut() .current_reachability_constraints_mut()
.add_atom(predicate_id); .add_atom(predicate_id);
self.current_use_def_map_mut() self.current_use_def_map_mut()
.record_reachability_constraint(reachability_constraint); .record_reachability_constraint(reachability_constraint);
reachability_constraint reachability_constraint
@ -681,7 +687,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
subject: Expression<'db>, subject: Expression<'db>,
pattern: &ast::Pattern, pattern: &ast::Pattern,
guard: Option<&ast::Expr>, guard: Option<&ast::Expr>,
) -> Predicate<'db> { ) -> PredicateOrLiteral<'db> {
// This is called for the top-level pattern of each match arm. We need to create a // This is called for the top-level pattern of each match arm. We need to create a
// standalone expression for each arm of a match statement, since they can introduce // standalone expression for each arm of a match statement, since they can introduce
// constraints on the match subject. (Or more accurately, for the match arm's pattern, // constraints on the match subject. (Or more accurately, for the match arm's pattern,
@ -705,10 +711,10 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
guard, guard,
countme::Count::default(), countme::Count::default(),
); );
let predicate = Predicate { let predicate = PredicateOrLiteral::Predicate(Predicate {
node: PredicateNode::Pattern(pattern_predicate), node: PredicateNode::Pattern(pattern_predicate),
is_positive: true, is_positive: true,
}; });
self.record_narrowing_constraint(predicate); self.record_narrowing_constraint(predicate);
predicate predicate
} }
@ -1653,10 +1659,10 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
self.record_ambiguous_reachability(); self.record_ambiguous_reachability();
self.visit_expr(guard); self.visit_expr(guard);
let post_guard_eval = self.flow_snapshot(); let post_guard_eval = self.flow_snapshot();
let predicate = Predicate { let predicate = PredicateOrLiteral::Predicate(Predicate {
node: PredicateNode::Expression(guard_expr), node: PredicateNode::Expression(guard_expr),
is_positive: true, is_positive: true,
}; });
self.record_negated_narrowing_constraint(predicate); self.record_negated_narrowing_constraint(predicate);
let match_success_guard_failure = self.flow_snapshot(); let match_success_guard_failure = self.flow_snapshot();
self.flow_restore(post_guard_eval); self.flow_restore(post_guard_eval);

View file

@ -8,7 +8,7 @@
//! static reachability of a binding, and the reachability of a statement or expression. //! static reachability of a binding, and the reachability of a statement or expression.
use ruff_db::files::File; use ruff_db::files::File;
use ruff_index::{IndexVec, newtype_index}; use ruff_index::{Idx, IndexVec};
use ruff_python_ast::Singleton; use ruff_python_ast::Singleton;
use crate::db::Db; use crate::db::Db;
@ -17,9 +17,42 @@ use crate::semantic_index::global_scope;
use crate::semantic_index::place::{FileScopeId, ScopeId, ScopedPlaceId}; use crate::semantic_index::place::{FileScopeId, ScopeId, ScopedPlaceId};
// A scoped identifier for each `Predicate` in a scope. // A scoped identifier for each `Predicate` in a scope.
#[newtype_index] #[derive(Clone, Debug, Copy, PartialOrd, Ord, PartialEq, Eq, Hash, get_size2::GetSize)]
#[derive(Ord, PartialOrd, get_size2::GetSize)] pub(crate) struct ScopedPredicateId(u32);
pub(crate) struct ScopedPredicateId;
impl ScopedPredicateId {
/// A special ID that is used for an "always true" predicate.
pub(crate) const ALWAYS_TRUE: ScopedPredicateId = ScopedPredicateId(0xffff_ffff);
/// A special ID that is used for an "always false" predicate.
pub(crate) const ALWAYS_FALSE: ScopedPredicateId = ScopedPredicateId(0xffff_fffe);
const SMALLEST_TERMINAL: ScopedPredicateId = Self::ALWAYS_FALSE;
fn is_terminal(self) -> bool {
self >= Self::SMALLEST_TERMINAL
}
#[cfg(test)]
pub(crate) fn as_u32(self) -> u32 {
self.0
}
}
impl Idx for ScopedPredicateId {
#[inline]
fn new(value: usize) -> Self {
assert!(value <= (Self::SMALLEST_TERMINAL.0 as usize));
#[expect(clippy::cast_possible_truncation)]
Self(value as u32)
}
#[inline]
fn index(self) -> usize {
debug_assert!(!self.is_terminal());
self.0 as usize
}
}
// A collection of predicates for a given scope. // A collection of predicates for a given scope.
pub(crate) type Predicates<'db> = IndexVec<ScopedPredicateId, Predicate<'db>>; pub(crate) type Predicates<'db> = IndexVec<ScopedPredicateId, Predicate<'db>>;
@ -49,11 +82,22 @@ pub(crate) struct Predicate<'db> {
pub(crate) is_positive: bool, pub(crate) is_positive: bool,
} }
impl Predicate<'_> { #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub(crate) enum PredicateOrLiteral<'db> {
Literal(bool),
Predicate(Predicate<'db>),
}
impl PredicateOrLiteral<'_> {
pub(crate) fn negated(self) -> Self { pub(crate) fn negated(self) -> Self {
Self { match self {
node: self.node, PredicateOrLiteral::Literal(value) => PredicateOrLiteral::Literal(!value),
is_positive: !self.is_positive, PredicateOrLiteral::Predicate(Predicate { node, is_positive }) => {
PredicateOrLiteral::Predicate(Predicate {
node,
is_positive: !is_positive,
})
}
} }
} }
} }
@ -169,11 +213,11 @@ impl<'db> StarImportPlaceholderPredicate<'db> {
} }
} }
impl<'db> From<StarImportPlaceholderPredicate<'db>> for Predicate<'db> { impl<'db> From<StarImportPlaceholderPredicate<'db>> for PredicateOrLiteral<'db> {
fn from(predicate: StarImportPlaceholderPredicate<'db>) -> Self { fn from(predicate: StarImportPlaceholderPredicate<'db>) -> Self {
Predicate { PredicateOrLiteral::Predicate(Predicate {
node: PredicateNode::StarImportPlaceholder(predicate), node: PredicateNode::StarImportPlaceholder(predicate),
is_positive: true, is_positive: true,
} })
} }
} }

View file

@ -388,6 +388,11 @@ impl ReachabilityConstraintsBuilder {
&mut self, &mut self,
predicate: ScopedPredicateId, predicate: ScopedPredicateId,
) -> ScopedReachabilityConstraintId { ) -> ScopedReachabilityConstraintId {
if predicate == ScopedPredicateId::ALWAYS_FALSE {
ScopedReachabilityConstraintId::ALWAYS_FALSE
} else if predicate == ScopedPredicateId::ALWAYS_TRUE {
ScopedReachabilityConstraintId::ALWAYS_TRUE
} else {
self.add_interior(InteriorNode { self.add_interior(InteriorNode {
atom: predicate, atom: predicate,
if_true: ALWAYS_TRUE, if_true: ALWAYS_TRUE,
@ -395,6 +400,7 @@ impl ReachabilityConstraintsBuilder {
if_false: ALWAYS_FALSE, if_false: ALWAYS_FALSE,
}) })
} }
}
/// Adds a new reachability constraint that is the ternary NOT of an existing one. /// Adds a new reachability constraint that is the ternary NOT of an existing one.
pub(crate) fn add_not_constraint( pub(crate) fn add_not_constraint(

View file

@ -247,7 +247,8 @@ use crate::semantic_index::place::{
FileScopeId, PlaceExpr, PlaceExprWithFlags, ScopeKind, ScopedPlaceId, FileScopeId, PlaceExpr, PlaceExprWithFlags, ScopeKind, ScopedPlaceId,
}; };
use crate::semantic_index::predicate::{ use crate::semantic_index::predicate::{
Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, StarImportPlaceholderPredicate, Predicate, PredicateOrLiteral, Predicates, PredicatesBuilder, ScopedPredicateId,
StarImportPlaceholderPredicate,
}; };
use crate::semantic_index::reachability_constraints::{ use crate::semantic_index::reachability_constraints::{
ReachabilityConstraints, ReachabilityConstraintsBuilder, ScopedReachabilityConstraintId, ReachabilityConstraints, ReachabilityConstraintsBuilder, ScopedReachabilityConstraintId,
@ -805,11 +806,25 @@ impl<'db> UseDefMapBuilder<'db> {
); );
} }
pub(super) fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId { pub(super) fn add_predicate(
self.predicates.add_predicate(predicate) &mut self,
predicate: PredicateOrLiteral<'db>,
) -> ScopedPredicateId {
match predicate {
PredicateOrLiteral::Predicate(predicate) => self.predicates.add_predicate(predicate),
PredicateOrLiteral::Literal(true) => ScopedPredicateId::ALWAYS_TRUE,
PredicateOrLiteral::Literal(false) => ScopedPredicateId::ALWAYS_FALSE,
}
} }
pub(super) fn record_narrowing_constraint(&mut self, predicate: ScopedPredicateId) { pub(super) fn record_narrowing_constraint(&mut self, predicate: ScopedPredicateId) {
if predicate == ScopedPredicateId::ALWAYS_TRUE
|| predicate == ScopedPredicateId::ALWAYS_FALSE
{
// No need to record a narrowing constraint for `True` or `False`.
return;
}
let narrowing_constraint = predicate.into(); let narrowing_constraint = predicate.into();
for state in &mut self.place_states { for state in &mut self.place_states {
state state

View file

@ -431,6 +431,7 @@ impl PlaceState {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use ruff_index::Idx;
use crate::semantic_index::predicate::ScopedPredicateId; use crate::semantic_index::predicate::ScopedPredicateId;
@ -514,7 +515,7 @@ mod tests {
false, false,
true, true,
); );
let predicate = ScopedPredicateId::from_u32(0).into(); let predicate = ScopedPredicateId::new(0).into();
sym.record_narrowing_constraint(&mut narrowing_constraints, predicate); sym.record_narrowing_constraint(&mut narrowing_constraints, predicate);
assert_bindings(&narrowing_constraints, &sym, &["1<0>"]); assert_bindings(&narrowing_constraints, &sym, &["1<0>"]);
@ -533,7 +534,7 @@ mod tests {
false, false,
true, true,
); );
let predicate = ScopedPredicateId::from_u32(0).into(); let predicate = ScopedPredicateId::new(0).into();
sym1a.record_narrowing_constraint(&mut narrowing_constraints, predicate); sym1a.record_narrowing_constraint(&mut narrowing_constraints, predicate);
let mut sym1b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); let mut sym1b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE);
@ -543,7 +544,7 @@ mod tests {
false, false,
true, true,
); );
let predicate = ScopedPredicateId::from_u32(0).into(); let predicate = ScopedPredicateId::new(0).into();
sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate); sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate);
sym1a.merge( sym1a.merge(
@ -562,7 +563,7 @@ mod tests {
false, false,
true, true,
); );
let predicate = ScopedPredicateId::from_u32(1).into(); let predicate = ScopedPredicateId::new(1).into();
sym2a.record_narrowing_constraint(&mut narrowing_constraints, predicate); sym2a.record_narrowing_constraint(&mut narrowing_constraints, predicate);
let mut sym1b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); let mut sym1b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE);
@ -572,7 +573,7 @@ mod tests {
false, false,
true, true,
); );
let predicate = ScopedPredicateId::from_u32(2).into(); let predicate = ScopedPredicateId::new(2).into();
sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate); sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate);
sym2a.merge( sym2a.merge(
@ -591,7 +592,7 @@ mod tests {
false, false,
true, true,
); );
let predicate = ScopedPredicateId::from_u32(3).into(); let predicate = ScopedPredicateId::new(3).into();
sym3a.record_narrowing_constraint(&mut narrowing_constraints, predicate); sym3a.record_narrowing_constraint(&mut narrowing_constraints, predicate);
let sym2b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); let sym2b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE);