more cleanup

This commit is contained in:
Eric Mark Martin 2025-12-20 03:05:52 -05:00
parent f8185ce8be
commit 051894165d
2 changed files with 65 additions and 123 deletions

View file

@ -217,7 +217,6 @@ def is_b(val: object) -> TypeGuard[B]:
def _(x: P):
if isinstance(x, A) or is_b(x):
# currently reveals `(P & A) | (P & B)`, should reveal `(P & A) | B`
reveal_type(x) # revealed: (P & A) | B
```

View file

@ -68,7 +68,7 @@ pub(crate) fn infer_narrowing_constraint<'db>(
PredicateNode::StarImportPlaceholder(_) => return None,
};
if let Some(constraints) = constraints {
constraints.constraints.get(&place).cloned()
constraints.get(&place).cloned()
} else {
None
}
@ -78,7 +78,7 @@ pub(crate) fn infer_narrowing_constraint<'db>(
fn all_narrowing_constraints_for_pattern<'db>(
db: &'db dyn Db,
pattern: PatternPredicate<'db>,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
let module = parsed_module(db, pattern.file(db)).load(db);
NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), true).finish()
}
@ -91,7 +91,7 @@ fn all_narrowing_constraints_for_pattern<'db>(
fn all_narrowing_constraints_for_expression<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
let module = parsed_module(db, expression.file(db)).load(db);
NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), true)
.finish()
@ -105,7 +105,7 @@ fn all_narrowing_constraints_for_expression<'db>(
fn all_negative_narrowing_constraints_for_expression<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
let module = parsed_module(db, expression.file(db)).load(db);
NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), false)
.finish()
@ -115,7 +115,7 @@ fn all_negative_narrowing_constraints_for_expression<'db>(
fn all_negative_narrowing_constraints_for_pattern<'db>(
db: &'db dyn Db,
pattern: PatternPredicate<'db>,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
let module = parsed_module(db, pattern.file(db)).load(db);
NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), false).finish()
}
@ -124,7 +124,7 @@ fn constraints_for_expression_cycle_initial<'db>(
_db: &'db dyn Db,
_id: salsa::Id,
_expression: Expression<'db>,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
None
}
@ -132,7 +132,7 @@ fn negative_constraints_for_expression_cycle_initial<'db>(
_db: &'db dyn Db,
_id: salsa::Id,
_expression: Expression<'db>,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
None
}
@ -283,10 +283,10 @@ impl ClassInfoConstraintFunction {
/// intersecting
///
/// For example, `(Conjunction { constraint: A, typeguard: Some(B) } &
/// Conjunction { constraint: C, typeguard: Some(D)})` evlaluates to
/// Conjunction { constraint: C, typeguard: Some(D)})` evaluates to
/// `Conjunction { constraint: C, typeguard: Some(D) }` because the type guard
/// in the second clobbers the first.
#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy)]
/// in the second conjunct clobbers that in the first.
#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy, salsa::Update, get_size2::GetSize)]
struct Conjunction<'db> {
/// The intersected constraints (represented as a type to intersect the guard with)
constraint: Type<'db>,
@ -294,8 +294,6 @@ struct Conjunction<'db> {
typeguard: Option<Type<'db>>,
}
impl get_size2::GetSize for Conjunction<'_> {}
impl<'db> Conjunction<'db> {
/// Create a new conjunction with just a regular constraint
fn regular(constraint: Type<'db>) -> Self {
@ -343,14 +341,12 @@ impl<'db> Conjunction<'db> {
/// - `f(x) or g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]`
/// => `[Conjunction { constraint: A, typeguard: None }, Conjunction { constraint: object, typeguard: Some(B) }]`
/// => evaluates to `(P & A) | B`, where `P` is our previously-known type
#[derive(Hash, PartialEq, Debug, Eq, Clone)]
#[derive(Hash, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)]
pub(crate) struct NarrowingConstraint<'db> {
/// Disjunction of conjunctions (DNF)
disjuncts: SmallVec<[Conjunction<'db>; 1]>,
}
impl get_size2::GetSize for NarrowingConstraint<'_> {}
impl<'db> NarrowingConstraint<'db> {
/// Create a constraint from a regular (non-`TypeGuard`) type
pub(crate) fn regular(constraint: Type<'db>) -> Self {
@ -366,17 +362,18 @@ impl<'db> NarrowingConstraint<'db> {
}
}
/// Merge two constraints, taking their intersection but respecting `TypeGuard` semantics
/// Merge two constraints, taking their intersection but respecting `TypeGuard` semantics (with `other` winning)
pub(crate) fn merge_constraint_and(&self, other: &Self, db: &'db dyn Db) -> Self {
let mut new_disjuncts = SmallVec::new();
// Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...)
// becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ...
for left_conj in &self.disjuncts {
for right_conj in &other.disjuncts {
let new_disjuncts = self
.disjuncts
.iter()
.cartesian_product(other.disjuncts.iter())
.map(|(left_conj, right_conj)| {
if right_conj.typeguard.is_some() {
// If the right conjunct has a TypeGuard, it "wins" the conjunction
new_disjuncts.push(*right_conj);
*right_conj
} else {
// Intersect the regular constraints
let new_regular = IntersectionBuilder::new(db)
@ -384,13 +381,13 @@ impl<'db> NarrowingConstraint<'db> {
.add_positive(right_conj.constraint)
.build();
new_disjuncts.push(Conjunction {
Conjunction {
constraint: new_regular,
typeguard: left_conj.typeguard,
});
}
}
}
}
})
.collect::<SmallVec<_>>();
NarrowingConstraint {
disjuncts: new_disjuncts,
@ -416,59 +413,7 @@ impl<'db> From<Type<'db>> for NarrowingConstraint<'db> {
}
}
/// Internal representation of constraints with DNF structure for tracking `TypeGuard` semantics.
///
/// This is a newtype wrapper around `FxHashMap<ScopedPlaceId, NarrowingConstraint<'db>>` that
/// provides methods for working with constraints during boolean operation evaluation.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
struct InternalConstraints<'db> {
constraints: FxHashMap<ScopedPlaceId, NarrowingConstraint<'db>>,
}
impl get_size2::GetSize for InternalConstraints<'_> {}
// SAFETY: InternalConstraints contains only `'db` lifetimes which are covariant,
// and the inner types (FxHashMap, ScopedPlaceId, NarrowingConstraint) are all safe to transmute
unsafe impl salsa::Update for InternalConstraints<'_> {
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
let old_ref = unsafe { &mut (*old_pointer) };
if *old_ref != new_value {
*old_ref = new_value;
true
} else {
false
}
}
}
impl<'db> InternalConstraints<'db> {
/// Insert a regular (non-`TypeGuard`) constraint for a place
fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>) {
self.constraints
.insert(place, NarrowingConstraint::regular(ty));
}
/// Convert internal constraints to public constraints by evaluating each DNF constraint to a Type
fn evaluate_type_constraints(self, db: &'db dyn Db) -> NarrowingConstraints<'db> {
self.constraints
.into_iter()
.map(|(place, constraint)| (place, constraint.evaluate_type_constraint(db)))
.collect()
}
}
impl<'db> FromIterator<(ScopedPlaceId, NarrowingConstraint<'db>)> for InternalConstraints<'db> {
fn from_iter<T: IntoIterator<Item = (ScopedPlaceId, NarrowingConstraint<'db>)>>(
iter: T,
) -> Self {
Self {
constraints: FxHashMap::from_iter(iter),
}
}
}
/// Public representation of constraints as returned by tracked functions
type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, Type<'db>>;
type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, NarrowingConstraint<'db>>;
/// Merge constraints with AND semantics (intersection/conjunction).
///
@ -480,16 +425,16 @@ type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, Type<'db>>;
/// - Take the right conjunct if it has a `TypeGuard`
/// - Intersect the constraints normally otherwise
fn merge_constraints_and<'db>(
into: &mut InternalConstraints<'db>,
from: &InternalConstraints<'db>,
into: &mut NarrowingConstraints<'db>,
from: &NarrowingConstraints<'db>,
db: &'db dyn Db,
) {
for (key, from_constraint) in &from.constraints {
match into.constraints.entry(*key) {
for (key, from_constraint) in from {
match into.entry(*key) {
Entry::Occupied(mut entry) => {
let into_constraint = entry.get();
entry.insert(into_constraint.merge_constraint_and(&from_constraint, db));
entry.insert(into_constraint.merge_constraint_and(from_constraint, db));
}
Entry::Vacant(entry) => {
entry.insert(from_constraint.clone());
@ -506,19 +451,15 @@ fn merge_constraints_and<'db>(
/// However, if a place appears in only one branch of the OR, we need to widen it
/// to `object` in the overall result (because the other branch doesn't constrain it).
fn merge_constraints_or<'db>(
into: &mut InternalConstraints<'db>,
from: &InternalConstraints<'db>,
into: &mut NarrowingConstraints<'db>,
from: &NarrowingConstraints<'db>,
_db: &'db dyn Db,
) {
// For places that appear in `into` but not in `from`, widen to object
for (key, value) in &mut into.constraints {
if !from.constraints.contains_key(key) {
*value = NarrowingConstraint::regular(Type::object());
}
}
into.retain(|key, _| from.contains_key(key));
for (key, from_constraint) in &from.constraints {
match into.constraints.entry(*key) {
for (key, from_constraint) in from {
match into.entry(*key) {
Entry::Occupied(mut entry) => {
// Simply concatenate the disjuncts
entry
@ -595,8 +536,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
}
fn finish(mut self) -> Option<InternalConstraints<'db>> {
let mut constraints: Option<InternalConstraints<'db>> = match self.predicate {
fn finish(mut self) -> Option<NarrowingConstraints<'db>> {
let mut constraints: Option<NarrowingConstraints<'db>> = match self.predicate {
PredicateNode::Expression(expression) => {
self.evaluate_expression_predicate(expression, self.is_positive)
}
@ -608,7 +549,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
};
if let Some(ref mut constraints) = constraints {
constraints.constraints.shrink_to_fit();
constraints.shrink_to_fit();
}
constraints
@ -618,7 +559,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self,
expression: Expression<'db>,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
let expression_node = expression.node_ref(self.db, self.module);
self.evaluate_expression_node_predicate(expression_node, expression, is_positive)
}
@ -628,7 +569,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expression_node: &ruff_python_ast::Expr,
expression: Expression<'db>,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
match expression_node {
ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => {
self.evaluate_simple_expr(expression_node, is_positive)
@ -653,7 +594,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
pattern_predicate_kind: &PatternPredicateKind<'db>,
subject: Expression<'db>,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
match pattern_predicate_kind {
PatternPredicateKind::Singleton(singleton) => {
self.evaluate_match_pattern_singleton(subject, *singleton, is_positive)
@ -678,7 +619,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self,
pattern: PatternPredicate<'db>,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
self.evaluate_pattern_predicate_kind(
pattern.kind(self.db),
pattern.subject(self.db),
@ -788,7 +729,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self,
expr: &ast::Expr,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
let target = place_expr(expr)?;
let place = self.expect_place(&target);
@ -798,7 +739,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
Type::AlwaysTruthy.negate(self.db)
};
Some(InternalConstraints::from_iter([(
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(ty),
)]))
@ -808,7 +749,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self,
expr_named: &ast::ExprNamed,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
self.evaluate_simple_expr(&expr_named.target, is_positive)
}
@ -1052,7 +993,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expr_compare: &ast::ExprCompare,
expression: Expression<'db>,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool {
matches!(
expr,
@ -1094,7 +1035,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let comparator_tuples = std::iter::once(&**left)
.chain(comparators)
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
let mut constraints = InternalConstraints::default();
let mut constraints = NarrowingConstraints::default();
let mut last_rhs_ty: Option<Type> = None;
@ -1113,7 +1054,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive)
{
let place = self.expect_place(&left);
constraints.insert_regular(place, ty);
constraints.insert(place, NarrowingConstraint::regular(ty));
}
}
ast::Expr::Call(ast::ExprCall {
@ -1159,10 +1100,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.is_some_and(|c| c.is_known(self.db, KnownClass::Type))
{
let place = self.expect_place(&target);
constraints.insert_regular(
constraints.insert(
place,
Type::instance(self.db, rhs_class.unknown_specialization(self.db))
.negate_if(self.db, !is_positive),
NarrowingConstraint::regular(
Type::instance(self.db, rhs_class.unknown_specialization(self.db))
.negate_if(self.db, !is_positive),
),
);
}
}
@ -1177,7 +1120,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expr_call: &ast::ExprCall,
expression: Expression<'db>,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
let inference = infer_expression_types(self.db, expression, TypeContext::default());
let callable_ty = inference.expression_type(&*expr_call.func);
@ -1214,7 +1157,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
_ => None,
}?;
Some(InternalConstraints::from_iter([place_and_constraint]))
Some(NarrowingConstraints::from_iter([place_and_constraint]))
}
// For the expression `len(E)`, we narrow the type based on whether len(E) is truthy
// (i.e., whether E is non-empty). We only narrow the parts of the type where we know
@ -1232,7 +1175,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
if let Some(narrowed_ty) = Self::narrow_type_by_len(self.db, arg_ty, is_positive) {
let target = place_expr(arg)?;
let place = self.expect_place(&target);
Some(InternalConstraints::from_iter([(
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(narrowed_ty),
)]))
@ -1263,7 +1206,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let constraint =
Type::protocol_with_readonly_members(self.db, [(attr, Type::object())]);
return Some(InternalConstraints::from_iter([(
return Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(constraint.negate_if(self.db, !is_positive)),
)]));
@ -1276,7 +1219,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
function
.generate_constraint(self.db, class_info_ty)
.map(|constraint| {
InternalConstraints::from_iter([(
NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(
constraint.negate_if(self.db, !is_positive),
@ -1305,7 +1248,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
subject: Expression<'db>,
singleton: ast::Singleton,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
@ -1315,7 +1258,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
ast::Singleton::False => Type::BooleanLiteral(false),
};
let ty = ty.negate_if(self.db, !is_positive);
Some(InternalConstraints::from_iter([(
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(ty),
)]))
@ -1327,7 +1270,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
cls: Expression<'db>,
kind: ClassPatternKind,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
if !kind.is_irrefutable() && !is_positive {
// A class pattern like `case Point(x=0, y=0)` is not irrefutable. In the positive case,
// we can still narrow the type of the match subject to `Point`. But in the negative case,
@ -1351,7 +1294,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
_ => return None,
};
Some(InternalConstraints::from_iter([(
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(narrowed_type),
)]))
@ -1362,7 +1305,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
subject: Expression<'db>,
value: Expression<'db>,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
let place = {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
self.expect_place(&subject)
@ -1374,7 +1317,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module);
self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive)
.map(|ty| InternalConstraints::from_iter([(place, NarrowingConstraint::regular(ty))]))
.map(|ty| NarrowingConstraints::from_iter([(place, NarrowingConstraint::regular(ty))]))
}
fn evaluate_match_pattern_or(
@ -1382,7 +1325,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
subject: Expression<'db>,
predicates: &Vec<PatternPredicateKind<'db>>,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
let db = self.db;
// DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints.
@ -1408,7 +1351,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expr_bool_op: &ExprBoolOp,
expression: Expression<'db>,
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
) -> Option<NarrowingConstraints<'db>> {
let inference = infer_expression_types(self.db, expression, TypeContext::default());
let sub_constraints = expr_bool_op
.values
@ -1427,7 +1370,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.collect::<Vec<_>>();
match (expr_bool_op.op, is_positive) {
(BoolOp::And, true) | (BoolOp::Or, false) => {
let mut aggregation: Option<InternalConstraints> = None;
let mut aggregation: Option<NarrowingConstraints> = None;
for sub_constraint in sub_constraints.into_iter().flatten() {
if let Some(ref mut some_aggregation) = aggregation {
merge_constraints_and(some_aggregation, &sub_constraint, self.db);