[ty] Eagerly simplify 'True' and 'False' constraints (#18998)

## Summary

Simplifies literal `True` and `False` conditions to `ALWAYS_TRUE` /
`ALWAYS_FALSE` during semantic index building. This allows us to eagerly
evaluate more constraints, which should help with performance (looks
like there is a tiny 1% improvement in instrumented benchmarks), but
also allows us to eliminate definitely-unreachable branches in
control-flow merging. This can lead to better type inference in some
cases because it allows us to retain narrowing constraints without
solving https://github.com/astral-sh/ty/issues/690 first:
```py
def _(c: int | None):
    if c is None:
        assert False
    
    reveal_type(c)  # int, previously: int | None
```

closes https://github.com/astral-sh/ty/issues/713

## Test Plan

* Regression test for https://github.com/astral-sh/ty/issues/713
* Made sure that all ecosystem diffs trace back to removed false
positives
This commit is contained in:
David Peter 2025-06-30 13:11:52 +02:00 committed by GitHub
parent 54769ac9f9
commit db3dcd8ad6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 129 additions and 43 deletions

View file

@ -35,8 +35,8 @@ use crate::semantic_index::place::{
PlaceExprWithFlags, PlaceTableBuilder, Scope, ScopeId, ScopeKind, ScopedPlaceId,
};
use crate::semantic_index::predicate::{
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, ScopedPredicateId,
StarImportPlaceholderPredicate,
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, PredicateOrLiteral,
ScopedPredicateId, StarImportPlaceholderPredicate,
};
use crate::semantic_index::re_exports::exported_names;
use crate::semantic_index::reachability_constraints::{
@ -535,29 +535,34 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
fn record_expression_narrowing_constraint(
&mut self,
precide_node: &ast::Expr,
) -> Predicate<'db> {
) -> PredicateOrLiteral<'db> {
let predicate = self.build_predicate(precide_node);
self.record_narrowing_constraint(predicate);
predicate
}
fn build_predicate(&mut self, predicate_node: &ast::Expr) -> Predicate<'db> {
fn build_predicate(&mut self, predicate_node: &ast::Expr) -> PredicateOrLiteral<'db> {
let expression = self.add_standalone_expression(predicate_node);
Predicate {
node: PredicateNode::Expression(expression),
is_positive: true,
if let Some(boolean_literal) = predicate_node.as_boolean_literal_expr() {
PredicateOrLiteral::Literal(boolean_literal.value)
} else {
PredicateOrLiteral::Predicate(Predicate {
node: PredicateNode::Expression(expression),
is_positive: true,
})
}
}
/// Adds a new predicate to the list of all predicates, but does not record it. Returns the
/// predicate ID for later recording using
/// [`SemanticIndexBuilder::record_narrowing_constraint_id`].
fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId {
fn add_predicate(&mut self, predicate: PredicateOrLiteral<'db>) -> ScopedPredicateId {
self.current_use_def_map_mut().add_predicate(predicate)
}
/// Negates a predicate and adds it to the list of all predicates, does not record it.
fn add_negated_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId {
fn add_negated_predicate(&mut self, predicate: PredicateOrLiteral<'db>) -> ScopedPredicateId {
self.current_use_def_map_mut()
.add_predicate(predicate.negated())
}
@ -569,7 +574,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
}
/// Adds and records a narrowing constraint, i.e. adds it to all live bindings.
fn record_narrowing_constraint(&mut self, predicate: Predicate<'db>) {
fn record_narrowing_constraint(&mut self, predicate: PredicateOrLiteral<'db>) {
let use_def = self.current_use_def_map_mut();
let predicate_id = use_def.add_predicate(predicate);
use_def.record_narrowing_constraint(predicate_id);
@ -579,7 +584,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
/// bindings.
fn record_negated_narrowing_constraint(
&mut self,
predicate: Predicate<'db>,
predicate: PredicateOrLiteral<'db>,
) -> ScopedPredicateId {
let id = self.add_negated_predicate(predicate);
self.record_narrowing_constraint_id(id);
@ -603,7 +608,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
/// we know that all statements that follow in this path of control flow will be unreachable.
fn record_reachability_constraint(
&mut self,
predicate: Predicate<'db>,
predicate: PredicateOrLiteral<'db>,
) -> ScopedReachabilityConstraintId {
let predicate_id = self.add_predicate(predicate);
self.record_reachability_constraint_id(predicate_id)
@ -617,6 +622,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
let reachability_constraint = self
.current_reachability_constraints_mut()
.add_atom(predicate_id);
self.current_use_def_map_mut()
.record_reachability_constraint(reachability_constraint);
reachability_constraint
@ -681,7 +687,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
subject: Expression<'db>,
pattern: &ast::Pattern,
guard: Option<&ast::Expr>,
) -> Predicate<'db> {
) -> PredicateOrLiteral<'db> {
// This is called for the top-level pattern of each match arm. We need to create a
// standalone expression for each arm of a match statement, since they can introduce
// constraints on the match subject. (Or more accurately, for the match arm's pattern,
@ -705,10 +711,10 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
guard,
countme::Count::default(),
);
let predicate = Predicate {
let predicate = PredicateOrLiteral::Predicate(Predicate {
node: PredicateNode::Pattern(pattern_predicate),
is_positive: true,
};
});
self.record_narrowing_constraint(predicate);
predicate
}
@ -1653,10 +1659,10 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
self.record_ambiguous_reachability();
self.visit_expr(guard);
let post_guard_eval = self.flow_snapshot();
let predicate = Predicate {
let predicate = PredicateOrLiteral::Predicate(Predicate {
node: PredicateNode::Expression(guard_expr),
is_positive: true,
};
});
self.record_negated_narrowing_constraint(predicate);
let match_success_guard_failure = self.flow_snapshot();
self.flow_restore(post_guard_eval);

View file

@ -8,7 +8,7 @@
//! static reachability of a binding, and the reachability of a statement or expression.
use ruff_db::files::File;
use ruff_index::{IndexVec, newtype_index};
use ruff_index::{Idx, IndexVec};
use ruff_python_ast::Singleton;
use crate::db::Db;
@ -17,9 +17,42 @@ use crate::semantic_index::global_scope;
use crate::semantic_index::place::{FileScopeId, ScopeId, ScopedPlaceId};
// A scoped identifier for each `Predicate` in a scope.
#[newtype_index]
#[derive(Ord, PartialOrd, get_size2::GetSize)]
pub(crate) struct ScopedPredicateId;
#[derive(Clone, Debug, Copy, PartialOrd, Ord, PartialEq, Eq, Hash, get_size2::GetSize)]
pub(crate) struct ScopedPredicateId(u32);
impl ScopedPredicateId {
/// A special ID that is used for an "always true" predicate.
pub(crate) const ALWAYS_TRUE: ScopedPredicateId = ScopedPredicateId(0xffff_ffff);
/// A special ID that is used for an "always false" predicate.
pub(crate) const ALWAYS_FALSE: ScopedPredicateId = ScopedPredicateId(0xffff_fffe);
const SMALLEST_TERMINAL: ScopedPredicateId = Self::ALWAYS_FALSE;
fn is_terminal(self) -> bool {
self >= Self::SMALLEST_TERMINAL
}
#[cfg(test)]
pub(crate) fn as_u32(self) -> u32 {
self.0
}
}
impl Idx for ScopedPredicateId {
#[inline]
fn new(value: usize) -> Self {
assert!(value <= (Self::SMALLEST_TERMINAL.0 as usize));
#[expect(clippy::cast_possible_truncation)]
Self(value as u32)
}
#[inline]
fn index(self) -> usize {
debug_assert!(!self.is_terminal());
self.0 as usize
}
}
// A collection of predicates for a given scope.
pub(crate) type Predicates<'db> = IndexVec<ScopedPredicateId, Predicate<'db>>;
@ -49,11 +82,22 @@ pub(crate) struct Predicate<'db> {
pub(crate) is_positive: bool,
}
impl Predicate<'_> {
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub(crate) enum PredicateOrLiteral<'db> {
Literal(bool),
Predicate(Predicate<'db>),
}
impl PredicateOrLiteral<'_> {
pub(crate) fn negated(self) -> Self {
Self {
node: self.node,
is_positive: !self.is_positive,
match self {
PredicateOrLiteral::Literal(value) => PredicateOrLiteral::Literal(!value),
PredicateOrLiteral::Predicate(Predicate { node, is_positive }) => {
PredicateOrLiteral::Predicate(Predicate {
node,
is_positive: !is_positive,
})
}
}
}
}
@ -169,11 +213,11 @@ impl<'db> StarImportPlaceholderPredicate<'db> {
}
}
impl<'db> From<StarImportPlaceholderPredicate<'db>> for Predicate<'db> {
impl<'db> From<StarImportPlaceholderPredicate<'db>> for PredicateOrLiteral<'db> {
fn from(predicate: StarImportPlaceholderPredicate<'db>) -> Self {
Predicate {
PredicateOrLiteral::Predicate(Predicate {
node: PredicateNode::StarImportPlaceholder(predicate),
is_positive: true,
}
})
}
}

View file

@ -388,12 +388,18 @@ impl ReachabilityConstraintsBuilder {
&mut self,
predicate: ScopedPredicateId,
) -> ScopedReachabilityConstraintId {
self.add_interior(InteriorNode {
atom: predicate,
if_true: ALWAYS_TRUE,
if_ambiguous: AMBIGUOUS,
if_false: ALWAYS_FALSE,
})
if predicate == ScopedPredicateId::ALWAYS_FALSE {
ScopedReachabilityConstraintId::ALWAYS_FALSE
} else if predicate == ScopedPredicateId::ALWAYS_TRUE {
ScopedReachabilityConstraintId::ALWAYS_TRUE
} else {
self.add_interior(InteriorNode {
atom: predicate,
if_true: ALWAYS_TRUE,
if_ambiguous: AMBIGUOUS,
if_false: ALWAYS_FALSE,
})
}
}
/// Adds a new reachability constraint that is the ternary NOT of an existing one.

View file

@ -247,7 +247,8 @@ use crate::semantic_index::place::{
FileScopeId, PlaceExpr, PlaceExprWithFlags, ScopeKind, ScopedPlaceId,
};
use crate::semantic_index::predicate::{
Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, StarImportPlaceholderPredicate,
Predicate, PredicateOrLiteral, Predicates, PredicatesBuilder, ScopedPredicateId,
StarImportPlaceholderPredicate,
};
use crate::semantic_index::reachability_constraints::{
ReachabilityConstraints, ReachabilityConstraintsBuilder, ScopedReachabilityConstraintId,
@ -805,11 +806,25 @@ impl<'db> UseDefMapBuilder<'db> {
);
}
pub(super) fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId {
self.predicates.add_predicate(predicate)
pub(super) fn add_predicate(
&mut self,
predicate: PredicateOrLiteral<'db>,
) -> ScopedPredicateId {
match predicate {
PredicateOrLiteral::Predicate(predicate) => self.predicates.add_predicate(predicate),
PredicateOrLiteral::Literal(true) => ScopedPredicateId::ALWAYS_TRUE,
PredicateOrLiteral::Literal(false) => ScopedPredicateId::ALWAYS_FALSE,
}
}
pub(super) fn record_narrowing_constraint(&mut self, predicate: ScopedPredicateId) {
if predicate == ScopedPredicateId::ALWAYS_TRUE
|| predicate == ScopedPredicateId::ALWAYS_FALSE
{
// No need to record a narrowing constraint for `True` or `False`.
return;
}
let narrowing_constraint = predicate.into();
for state in &mut self.place_states {
state

View file

@ -431,6 +431,7 @@ impl PlaceState {
#[cfg(test)]
mod tests {
use super::*;
use ruff_index::Idx;
use crate::semantic_index::predicate::ScopedPredicateId;
@ -514,7 +515,7 @@ mod tests {
false,
true,
);
let predicate = ScopedPredicateId::from_u32(0).into();
let predicate = ScopedPredicateId::new(0).into();
sym.record_narrowing_constraint(&mut narrowing_constraints, predicate);
assert_bindings(&narrowing_constraints, &sym, &["1<0>"]);
@ -533,7 +534,7 @@ mod tests {
false,
true,
);
let predicate = ScopedPredicateId::from_u32(0).into();
let predicate = ScopedPredicateId::new(0).into();
sym1a.record_narrowing_constraint(&mut narrowing_constraints, predicate);
let mut sym1b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE);
@ -543,7 +544,7 @@ mod tests {
false,
true,
);
let predicate = ScopedPredicateId::from_u32(0).into();
let predicate = ScopedPredicateId::new(0).into();
sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate);
sym1a.merge(
@ -562,7 +563,7 @@ mod tests {
false,
true,
);
let predicate = ScopedPredicateId::from_u32(1).into();
let predicate = ScopedPredicateId::new(1).into();
sym2a.record_narrowing_constraint(&mut narrowing_constraints, predicate);
let mut sym1b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE);
@ -572,7 +573,7 @@ mod tests {
false,
true,
);
let predicate = ScopedPredicateId::from_u32(2).into();
let predicate = ScopedPredicateId::new(2).into();
sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate);
sym2a.merge(
@ -591,7 +592,7 @@ mod tests {
false,
true,
);
let predicate = ScopedPredicateId::from_u32(3).into();
let predicate = ScopedPredicateId::new(3).into();
sym3a.record_narrowing_constraint(&mut narrowing_constraints, predicate);
let sym2b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE);