diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index f5cf1b4d38..d5a1373f6b 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -68,7 +68,10 @@ pub(crate) fn infer_narrowing_constraint<'db>( PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(constraints) = constraints { - constraints.get(&place).copied() + constraints + .constraints + .get(&place) + .map(|constraint| constraint.clone().evaluate_type_constraint(db)) } else { None } @@ -78,7 +81,7 @@ pub(crate) fn infer_narrowing_constraint<'db>( fn all_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternPredicate<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, pattern.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), true).finish() } @@ -91,7 +94,7 @@ fn all_narrowing_constraints_for_pattern<'db>( fn all_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, expression.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), true) .finish() @@ -105,7 +108,7 @@ fn all_narrowing_constraints_for_expression<'db>( fn all_negative_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, expression.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), false) .finish() @@ -115,7 +118,7 @@ fn all_negative_narrowing_constraints_for_expression<'db>( fn all_negative_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternPredicate<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, pattern.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), false).finish() } @@ -124,7 +127,7 @@ fn constraints_for_expression_cycle_initial<'db>( _db: &'db dyn Db, _id: salsa::Id, _expression: Expression<'db>, -) -> Option> { +) -> Option> { None } @@ -132,7 +135,7 @@ fn negative_constraints_for_expression_cycle_initial<'db>( _db: &'db dyn Db, _id: salsa::Id, _expression: Expression<'db>, -) -> Option> { +) -> Option> { None } @@ -389,11 +392,27 @@ impl<'db> From> for NarrowingConstraint<'db> { /// /// This is a newtype wrapper around `FxHashMap>` that /// provides methods for working with constraints during boolean operation evaluation. -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq, Eq)] struct InternalConstraints<'db> { constraints: FxHashMap>, } +impl get_size2::GetSize for InternalConstraints<'_> {} + +// SAFETY: InternalConstraints contains only `'db` lifetimes which are covariant, +// and the inner types (FxHashMap, ScopedPlaceId, NarrowingConstraint) are all safe to transmute +unsafe impl salsa::Update for InternalConstraints<'_> { + unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { + let old_ref = unsafe { &mut (*old_pointer) }; + if *old_ref != new_value { + *old_ref = new_value; + true + } else { + false + } + } +} + impl<'db> InternalConstraints<'db> { /// Insert a regular (non-`TypeGuard`) constraint for a place fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>) { @@ -576,8 +595,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } } - fn finish(mut self) -> Option> { - let constraints: Option> = match self.predicate { + fn finish(mut self) -> Option> { + let mut constraints: Option> = match self.predicate { PredicateNode::Expression(expression) => { self.evaluate_expression_predicate(expression, self.is_positive) } @@ -587,12 +606,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { PredicateNode::ReturnsNever(_) => return None, PredicateNode::StarImportPlaceholder(_) => return None, }; - if let Some(mut constraints) = constraints { + + if let Some(ref mut constraints) = constraints { constraints.constraints.shrink_to_fit(); - Some(constraints.evaluate_type_constraints(self.db)) - } else { - None } + + constraints } fn evaluate_expression_predicate(