From ea1aa9ebfec655a07c1dffb7907fd77a4a97c8c4 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Mon, 11 Aug 2025 15:42:53 -0700 Subject: [PATCH] [ty] use interior mutability in type visitors (#19871) ## Summary Type visitors are conceptually immutable, they just internally track the types they've seen (and some maintain a cache of results.) Passing around mutable visitors everywhere can get us into borrow-checker trouble in some cases, where we need to recursively pass along the visitor inside more than one closure with non-disjoint lifetime. Use interior mutability (via `RefCell` and `Cell`) inside the visitors instead, to allow us to pass around shared references. ## Test Plan Existing tests. --- crates/ty_python_semantic/src/types.rs | 119 +++++++----------- crates/ty_python_semantic/src/types/class.rs | 14 +-- .../src/types/class_base.rs | 6 +- crates/ty_python_semantic/src/types/cyclic.rs | 19 +-- .../ty_python_semantic/src/types/function.rs | 15 +-- .../ty_python_semantic/src/types/generics.rs | 20 +-- .../ty_python_semantic/src/types/instance.rs | 19 ++- .../src/types/protocol_class.rs | 18 ++- .../src/types/signatures.rs | 20 +-- .../src/types/subclass_of.rs | 14 +-- crates/ty_python_semantic/src/types/tuple.rs | 22 ++-- .../ty_python_semantic/src/types/visitor.rs | 85 +++++-------- 12 files changed, 131 insertions(+), 240 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 805e41be7a..0b1ba323c2 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -395,7 +395,7 @@ pub struct PropertyInstanceType<'db> { fn walk_property_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, property: PropertyInstanceType<'db>, - visitor: &mut V, + visitor: &V, ) { if let Some(getter) = property.getter(db) { visitor.visit_type(db, getter); @@ -419,7 +419,7 @@ impl<'db> PropertyInstanceType<'db> { Self::new(db, getter, setter) } - fn normalized_impl(self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self::new( db, self.getter(db).map(|ty| ty.normalized_impl(db, visitor)), @@ -1055,16 +1055,11 @@ impl<'db> Type<'db> { /// - Converts class-based protocols into synthesized protocols #[must_use] pub fn normalized(self, db: &'db dyn Db) -> Self { - let mut visitor = TypeTransformer::default(); - self.normalized_impl(db, &mut visitor) + self.normalized_impl(db, &TypeTransformer::default()) } #[must_use] - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { Type::Union(union) => { visitor.visit(self, |v| Type::Union(union.normalized_impl(db, v))) @@ -1764,21 +1759,20 @@ impl<'db> Type<'db> { /// Note: This function aims to have no false positives, but might return /// wrong `false` answers in some cases. pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool { - let mut visitor = PairVisitor::new(false); - self.is_disjoint_from_impl(db, other, &mut visitor) + self.is_disjoint_from_impl(db, other, &PairVisitor::new(false)) } pub(crate) fn is_disjoint_from_impl( self, db: &'db dyn Db, other: Type<'db>, - visitor: &mut PairVisitor<'db>, + visitor: &PairVisitor<'db>, ) -> bool { fn any_protocol_members_absent_or_disjoint<'db>( db: &'db dyn Db, protocol: ProtocolInstanceType<'db>, other: Type<'db>, - visitor: &mut PairVisitor<'db>, + visitor: &PairVisitor<'db>, ) -> bool { protocol.interface(db).members(db).any(|member| { other @@ -6102,7 +6096,7 @@ pub enum TypeMapping<'a, 'db> { fn walk_type_mapping<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, mapping: &TypeMapping<'_, 'db>, - visitor: &mut V, + visitor: &V, ) { match mapping { TypeMapping::Specialization(specialization) => { @@ -6131,7 +6125,7 @@ impl<'db> TypeMapping<'_, 'db> { } } - fn normalized_impl(&self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(&self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { TypeMapping::Specialization(specialization) => { TypeMapping::Specialization(specialization.normalized_impl(db, visitor)) @@ -6194,7 +6188,7 @@ pub enum KnownInstanceType<'db> { fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, known_instance: KnownInstanceType<'db>, - visitor: &mut V, + visitor: &V, ) { match known_instance { KnownInstanceType::SubscriptedProtocol(context) @@ -6217,7 +6211,7 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( } impl<'db> KnownInstanceType<'db> { - fn normalized_impl(self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { Self::SubscriptedProtocol(context) => { Self::SubscriptedProtocol(context.normalized_impl(db, visitor)) @@ -6643,11 +6637,7 @@ pub struct FieldInstance<'db> { impl get_size2::GetSize for FieldInstance<'_> {} impl<'db> FieldInstance<'db> { - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { FieldInstance::new( db, self.default_type(db).normalized_impl(db, visitor), @@ -6732,7 +6722,7 @@ impl get_size2::GetSize for TypeVarInstance<'_> {} fn walk_type_var_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, typevar: TypeVarInstance<'db>, - visitor: &mut V, + visitor: &V, ) { if let Some(bounds) = typevar.bound_or_constraints(db) { walk_type_var_bounds(db, bounds, visitor); @@ -6771,11 +6761,7 @@ impl<'db> TypeVarInstance<'db> { } } - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self::new( db, self.name(db), @@ -6836,7 +6822,7 @@ impl get_size2::GetSize for BoundTypeVarInstance<'_> {} fn walk_bound_type_var_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, bound_typevar: BoundTypeVarInstance<'db>, - visitor: &mut V, + visitor: &V, ) { visitor.visit_type_var_type(db, bound_typevar.typevar(db)); } @@ -6871,11 +6857,7 @@ impl<'db> BoundTypeVarInstance<'db> { .map(|ty| ty.apply_type_mapping(db, &TypeMapping::BindLegacyTypevars(binding_context))) } - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self::new( db, self.typevar(db).normalized_impl(db, visitor), @@ -6923,7 +6905,7 @@ pub enum TypeVarBoundOrConstraints<'db> { fn walk_type_var_bounds<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, bounds: TypeVarBoundOrConstraints<'db>, - visitor: &mut V, + visitor: &V, ) { match bounds { TypeVarBoundOrConstraints::UpperBound(bound) => visitor.visit_type(db, bound), @@ -6934,7 +6916,7 @@ fn walk_type_var_bounds<'db, V: visitor::TypeVisitor<'db> + ?Sized>( } impl<'db> TypeVarBoundOrConstraints<'db> { - fn normalized_impl(self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { TypeVarBoundOrConstraints::UpperBound(bound) => { TypeVarBoundOrConstraints::UpperBound(bound.normalized_impl(db, visitor)) @@ -7951,7 +7933,7 @@ impl get_size2::GetSize for BoundMethodType<'_> {} fn walk_bound_method_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, method: BoundMethodType<'db>, - visitor: &mut V, + visitor: &V, ) { visitor.visit_function_type(db, method.function(db)); visitor.visit_type(db, method.self_instance(db)); @@ -7972,7 +7954,7 @@ impl<'db> BoundMethodType<'db> { ) } - fn normalized_impl(self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self::new( db, self.function(db).normalized_impl(db, visitor), @@ -8025,7 +8007,7 @@ pub struct CallableType<'db> { pub(super) fn walk_callable_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, ty: CallableType<'db>, - visitor: &mut V, + visitor: &V, ) { for signature in &ty.signatures(db).overloads { walk_signature(db, signature, visitor); @@ -8089,7 +8071,7 @@ impl<'db> CallableType<'db> { /// Return a "normalized" version of this `Callable` type. /// /// See [`Type::normalized`] for more details. - fn normalized_impl(self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { CallableType::new( db, self.signatures(db).normalized_impl(db, visitor), @@ -8165,7 +8147,7 @@ pub enum MethodWrapperKind<'db> { pub(super) fn walk_method_wrapper_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, method_wrapper: MethodWrapperKind<'db>, - visitor: &mut V, + visitor: &V, ) { match method_wrapper { MethodWrapperKind::FunctionTypeDunderGet(function) => { @@ -8253,7 +8235,7 @@ impl<'db> MethodWrapperKind<'db> { } } - fn normalized_impl(self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { MethodWrapperKind::FunctionTypeDunderGet(function) => { MethodWrapperKind::FunctionTypeDunderGet(function.normalized_impl(db, visitor)) @@ -8413,7 +8395,7 @@ impl get_size2::GetSize for PEP695TypeAliasType<'_> {} fn walk_pep_695_type_alias<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, type_alias: PEP695TypeAliasType<'db>, - visitor: &mut V, + visitor: &V, ) { visitor.visit_type(db, type_alias.value_type(db)); } @@ -8437,7 +8419,7 @@ impl<'db> PEP695TypeAliasType<'db> { definition_expression_type(db, definition, &type_alias_stmt_node.value) } - fn normalized_impl(self, _db: &'db dyn Db, _visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(self, _db: &'db dyn Db, _visitor: &TypeTransformer<'db>) -> Self { self } } @@ -8460,13 +8442,13 @@ impl get_size2::GetSize for BareTypeAliasType<'_> {} fn walk_bare_type_alias<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, type_alias: BareTypeAliasType<'db>, - visitor: &mut V, + visitor: &V, ) { visitor.visit_type(db, type_alias.value(db)); } impl<'db> BareTypeAliasType<'db> { - fn normalized_impl(self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self::new( db, self.name(db), @@ -8487,7 +8469,7 @@ pub enum TypeAliasType<'db> { fn walk_type_alias_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, type_alias: TypeAliasType<'db>, - visitor: &mut V, + visitor: &V, ) { match type_alias { TypeAliasType::PEP695(type_alias) => { @@ -8500,11 +8482,7 @@ fn walk_type_alias_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( } impl<'db> TypeAliasType<'db> { - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { TypeAliasType::PEP695(type_alias) => { TypeAliasType::PEP695(type_alias.normalized_impl(db, visitor)) @@ -8554,7 +8532,7 @@ pub struct UnionType<'db> { pub(crate) fn walk_union<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, union: UnionType<'db>, - visitor: &mut V, + visitor: &V, ) { for element in union.elements(db) { visitor.visit_type(db, *element); @@ -8730,14 +8708,10 @@ impl<'db> UnionType<'db> { /// See [`Type::normalized`] for more details. #[must_use] pub(crate) fn normalized(self, db: &'db dyn Db) -> Self { - self.normalized_impl(db, &mut TypeTransformer::default()) + self.normalized_impl(db, &TypeTransformer::default()) } - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { let mut new_elements: Vec> = self .elements(db) .iter() @@ -8791,7 +8765,7 @@ impl get_size2::GetSize for IntersectionType<'_> {} pub(super) fn walk_intersection_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, intersection: IntersectionType<'db>, - visitor: &mut V, + visitor: &V, ) { for element in intersection.positive(db) { visitor.visit_type(db, *element); @@ -8808,19 +8782,14 @@ impl<'db> IntersectionType<'db> { /// See [`Type::normalized`] for more details. #[must_use] pub(crate) fn normalized(self, db: &'db dyn Db) -> Self { - let mut visitor = TypeTransformer::default(); - self.normalized_impl(db, &mut visitor) + self.normalized_impl(db, &TypeTransformer::default()) } - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { fn normalized_set<'db>( db: &'db dyn Db, elements: &FxOrderSet>, - visitor: &mut TypeTransformer<'db>, + visitor: &TypeTransformer<'db>, ) -> FxOrderSet> { let mut elements: FxOrderSet> = elements .iter() @@ -9084,7 +9053,7 @@ impl<'db> TypedDictType<'db> { fn walk_typed_dict_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, typed_dict: TypedDictType<'db>, - visitor: &mut V, + visitor: &V, ) { visitor.visit_type(db, typed_dict.defining_class(db).into()); } @@ -9143,7 +9112,7 @@ pub enum SuperOwnerKind<'db> { } impl<'db> SuperOwnerKind<'db> { - fn normalized_impl(self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { SuperOwnerKind::Dynamic(dynamic) => SuperOwnerKind::Dynamic(dynamic.normalized()), SuperOwnerKind::Class(class) => { @@ -9236,7 +9205,7 @@ impl get_size2::GetSize for BoundSuperType<'_> {} fn walk_bound_super_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, bound_super: BoundSuperType<'db>, - visitor: &mut V, + visitor: &V, ) { visitor.visit_type(db, bound_super.pivot_class(db).into()); visitor.visit_type(db, bound_super.owner(db).into_type()); @@ -9403,11 +9372,7 @@ impl<'db> BoundSuperType<'db> { } } - pub(super) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self::new( db, self.pivot_class(db).normalized_impl(db, visitor), @@ -9427,7 +9392,7 @@ pub struct TypeIsType<'db> { fn walk_typeis_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, typeis_type: TypeIsType<'db>, - visitor: &mut V, + visitor: &V, ) { visitor.visit_type(db, typeis_type.return_type(db)); } diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index f04e78da56..6de3a5fa5e 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -222,7 +222,7 @@ pub struct GenericAlias<'db> { pub(super) fn walk_generic_alias<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, alias: GenericAlias<'db>, - visitor: &mut V, + visitor: &V, ) { walk_specialization(db, alias.specialization(db), visitor); } @@ -231,11 +231,7 @@ pub(super) fn walk_generic_alias<'db, V: super::visitor::TypeVisitor<'db> + ?Siz impl get_size2::GetSize for GenericAlias<'_> {} impl<'db> GenericAlias<'db> { - pub(super) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self::new( db, self.origin(db), @@ -321,11 +317,7 @@ impl<'db> ClassType<'db> { } } - pub(super) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { Self::NonGeneric(_) => self, Self::Generic(generic) => Self::Generic(generic.normalized_impl(db, visitor)), diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index dfc90a1867..d381c94110 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -32,11 +32,7 @@ impl<'db> ClassBase<'db> { Self::Dynamic(DynamicType::Unknown) } - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized()), Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)), diff --git a/crates/ty_python_semantic/src/types/cyclic.rs b/crates/ty_python_semantic/src/types/cyclic.rs index e0395ad86d..3d89e3c140 100644 --- a/crates/ty_python_semantic/src/types/cyclic.rs +++ b/crates/ty_python_semantic/src/types/cyclic.rs @@ -2,6 +2,7 @@ use rustc_hash::FxHashMap; use crate::FxIndexSet; use crate::types::Type; +use std::cell::RefCell; use std::cmp::Eq; use std::hash::Hash; @@ -25,14 +26,14 @@ pub(crate) struct CycleDetector { /// it indicates that we've hit a cycle (due to a recursive type); /// we need to immediately short circuit the whole operation and return the fallback value. /// That's why we pop items off the end of `seen` after we've visited them. - seen: FxIndexSet, + seen: RefCell>, /// Unlike `seen`, this field is a pure performance optimisation (and an essential one). /// If the type we're trying to normalize is present in `cache`, it doesn't necessarily mean we've hit a cycle: /// it just means that we've already visited this inner type as part of a bigger call chain we're currently in. /// Since this cache is just a performance optimisation, it doesn't make sense to pop items off the end of the /// cache after they've been visited (it would sort-of defeat the point of a cache if we did!) - cache: FxHashMap, + cache: RefCell>, fallback: R, } @@ -40,25 +41,25 @@ pub(crate) struct CycleDetector { impl CycleDetector { pub(crate) fn new(fallback: R) -> Self { CycleDetector { - seen: FxIndexSet::default(), - cache: FxHashMap::default(), + seen: RefCell::new(FxIndexSet::default()), + cache: RefCell::new(FxHashMap::default()), fallback, } } - pub(crate) fn visit(&mut self, item: T, func: impl FnOnce(&mut Self) -> R) -> R { - if let Some(ty) = self.cache.get(&item) { + pub(crate) fn visit(&self, item: T, func: impl FnOnce(&Self) -> R) -> R { + if let Some(ty) = self.cache.borrow().get(&item) { return *ty; } // We hit a cycle - if !self.seen.insert(item) { + if !self.seen.borrow_mut().insert(item) { return self.fallback; } let ret = func(self); - self.seen.pop(); - self.cache.insert(item, ret); + self.seen.borrow_mut().pop(); + self.cache.borrow_mut().insert(item, ret); ret } diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index c1434f6cb0..9be755e333 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -436,7 +436,7 @@ pub struct FunctionLiteral<'db> { fn walk_function_literal<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, function: FunctionLiteral<'db>, - visitor: &mut V, + visitor: &V, ) { if let Some(context) = function.inherited_generic_context(db) { walk_generic_context(db, context, visitor); @@ -599,7 +599,7 @@ impl<'db> FunctionLiteral<'db> { ) } - fn normalized_impl(self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { let context = self .inherited_generic_context(db) .map(|ctx| ctx.normalized_impl(db, visitor)); @@ -627,7 +627,7 @@ impl get_size2::GetSize for FunctionType<'_> {} pub(super) fn walk_function_type<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, function: FunctionType<'db>, - visitor: &mut V, + visitor: &V, ) { walk_function_literal(db, function.literal(db), visitor); for mapping in function.type_mappings(db) { @@ -915,15 +915,10 @@ impl<'db> FunctionType<'db> { } pub(crate) fn normalized(self, db: &'db dyn Db) -> Self { - let mut visitor = TypeTransformer::default(); - self.normalized_impl(db, &mut visitor) + self.normalized_impl(db, &TypeTransformer::default()) } - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { let mappings: Box<_> = self .type_mappings(db) .iter() diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 20851a9f69..07dedb7041 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -99,7 +99,7 @@ pub struct GenericContext<'db> { pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, context: GenericContext<'db>, - visitor: &mut V, + visitor: &V, ) { for bound_typevar in context.variables(db) { visitor.visit_bound_type_var_type(db, *bound_typevar); @@ -355,11 +355,7 @@ impl<'db> GenericContext<'db> { Specialization::new(db, self, expanded.into_boxed_slice(), None) } - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { let variables: FxOrderSet<_> = self .variables(db) .iter() @@ -408,7 +404,7 @@ pub struct Specialization<'db> { pub(super) fn walk_specialization<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, specialization: Specialization<'db>, - visitor: &mut V, + visitor: &V, ) { walk_generic_context(db, specialization.generic_context(db), visitor); for ty in specialization.types(db) { @@ -510,11 +506,7 @@ impl<'db> Specialization<'db> { Specialization::new(db, self.generic_context(db), types, None) } - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { let types: Box<[_]> = self .types(db) .iter() @@ -672,7 +664,7 @@ pub struct PartialSpecialization<'a, 'db> { pub(super) fn walk_partial_specialization<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, specialization: &PartialSpecialization<'_, 'db>, - visitor: &mut V, + visitor: &V, ) { walk_generic_context(db, specialization.generic_context, visitor); for ty in &*specialization.types { @@ -705,7 +697,7 @@ impl<'db> PartialSpecialization<'_, 'db> { pub(crate) fn normalized_impl( &self, db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, + visitor: &TypeTransformer<'db>, ) -> PartialSpecialization<'db, 'db> { let generic_context = self.generic_context.normalized_impl(db, visitor); let types: Cow<_> = self diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 5e61cd3f13..7a7bc1c6ae 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -68,7 +68,7 @@ impl<'db> Type<'db> { SynthesizedProtocolType::new( db, ProtocolInterface::with_property_members(db, members), - &mut TypeTransformer::default(), + &TypeTransformer::default(), ), )) } @@ -99,7 +99,7 @@ pub struct NominalInstanceType<'db>( pub(super) fn walk_nominal_instance_type<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, nominal: NominalInstanceType<'db>, - visitor: &mut V, + visitor: &V, ) { visitor.visit_type(db, nominal.class(db).into()); } @@ -241,7 +241,7 @@ impl<'db> NominalInstanceType<'db> { pub(super) fn normalized_impl( self, db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, + visitor: &TypeTransformer<'db>, ) -> Type<'db> { match self.0 { NominalInstanceInner::ExactTuple(tuple) => { @@ -296,7 +296,7 @@ impl<'db> NominalInstanceType<'db> { self, db: &'db dyn Db, other: Self, - visitor: &mut PairVisitor<'db>, + visitor: &PairVisitor<'db>, ) -> bool { let self_spec = self.tuple_spec(db); if let Some(self_spec) = self_spec.as_deref() { @@ -421,7 +421,7 @@ pub struct ProtocolInstanceType<'db> { pub(super) fn walk_protocol_instance_type<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, protocol: ProtocolInstanceType<'db>, - visitor: &mut V, + visitor: &V, ) { walk_protocol_interface(db, protocol.inner.interface(db), visitor); } @@ -471,8 +471,7 @@ impl<'db> ProtocolInstanceType<'db> { /// /// See [`Type::normalized`] for more details. pub(super) fn normalized(self, db: &'db dyn Db) -> Type<'db> { - let mut visitor = TypeTransformer::default(); - self.normalized_impl(db, &mut visitor) + self.normalized_impl(db, &TypeTransformer::default()) } /// Return a "normalized" version of this `Protocol` type. @@ -481,7 +480,7 @@ impl<'db> ProtocolInstanceType<'db> { pub(super) fn normalized_impl( self, db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, + visitor: &TypeTransformer<'db>, ) -> Type<'db> { let object = Type::object(db); if object.satisfies_protocol(db, self, TypeRelation::Subtyping) { @@ -533,7 +532,7 @@ impl<'db> ProtocolInstanceType<'db> { self, _db: &'db dyn Db, _other: Self, - _visitor: &mut PairVisitor<'db>, + _visitor: &PairVisitor<'db>, ) -> bool { false } @@ -640,7 +639,7 @@ mod synthesized_protocol { pub(super) fn new( db: &'db dyn Db, interface: ProtocolInterface<'db>, - visitor: &mut TypeTransformer<'db>, + visitor: &TypeTransformer<'db>, ) -> Self { Self(interface.normalized_impl(db, visitor)) } diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index c262055e6c..f544e35103 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -82,7 +82,7 @@ impl get_size2::GetSize for ProtocolInterface<'_> {} pub(super) fn walk_protocol_interface<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, interface: ProtocolInterface<'db>, - visitor: &mut V, + visitor: &V, ) { for member in interface.members(db) { walk_protocol_member(db, &member, visitor); @@ -165,11 +165,7 @@ impl<'db> ProtocolInterface<'db> { .all(|member_name| other.inner(db).contains_key(member_name)) } - pub(super) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self::new( db, self.inner(db) @@ -253,10 +249,10 @@ pub(super) struct ProtocolMemberData<'db> { impl<'db> ProtocolMemberData<'db> { fn normalized(&self, db: &'db dyn Db) -> Self { - self.normalized_impl(db, &mut TypeTransformer::default()) + self.normalized_impl(db, &TypeTransformer::default()) } - fn normalized_impl(&self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(&self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self { kind: self.kind.normalized_impl(db, visitor), qualifiers: self.qualifiers, @@ -331,7 +327,7 @@ enum ProtocolMemberKind<'db> { } impl<'db> ProtocolMemberKind<'db> { - fn normalized_impl(&self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(&self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { ProtocolMemberKind::Method(callable) => { ProtocolMemberKind::Method(callable.normalized_impl(db, visitor)) @@ -404,7 +400,7 @@ pub(super) struct ProtocolMember<'a, 'db> { fn walk_protocol_member<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, member: &ProtocolMember<'_, 'db>, - visitor: &mut V, + visitor: &V, ) { match member.kind { ProtocolMemberKind::Method(method) => visitor.visit_callable_type(db, method), @@ -436,7 +432,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> { &self, db: &'db dyn Db, other: Type<'db>, - visitor: &mut PairVisitor<'db>, + visitor: &PairVisitor<'db>, ) -> bool { match &self.kind { // TODO: implement disjointness for property/method members as well as attribute members diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 9f1fc1a272..dfcc212ce5 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -61,11 +61,7 @@ impl<'db> CallableSignature<'db> { ) } - pub(crate) fn normalized_impl( - &self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(&self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self::from_overloads( self.overloads .iter() @@ -245,7 +241,7 @@ pub struct Signature<'db> { pub(super) fn walk_signature<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, signature: &Signature<'db>, - visitor: &mut V, + visitor: &V, ) { if let Some(generic_context) = &signature.generic_context { walk_generic_context(db, *generic_context, visitor); @@ -384,11 +380,7 @@ impl<'db> Signature<'db> { } } - pub(crate) fn normalized_impl( - &self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(&self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self { generic_context: self .generic_context @@ -1368,11 +1360,7 @@ impl<'db> Parameter<'db> { /// Normalize nested unions and intersections in the annotated type, if any. /// /// See [`Type::normalized`] for more details. - pub(crate) fn normalized_impl( - &self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(&self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { let Parameter { annotated_type, kind, diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index f5b7701820..fdedce211f 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -20,7 +20,7 @@ pub struct SubclassOfType<'db> { pub(super) fn walk_subclass_of_type<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, subclass_of: SubclassOfType<'db>, - visitor: &mut V, + visitor: &V, ) { visitor.visit_type(db, Type::from(subclass_of.subclass_of)); } @@ -185,11 +185,7 @@ impl<'db> SubclassOfType<'db> { } } - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self { subclass_of: self.subclass_of.normalized_impl(db, visitor), } @@ -252,11 +248,7 @@ impl<'db> SubclassOfInner<'db> { } } - pub(crate) fn normalized_impl( - self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)), Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized()), diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index f04346d7c8..121c3a9c55 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -135,7 +135,7 @@ pub struct TupleType<'db> { pub(super) fn walk_tuple_type<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, tuple: TupleType<'db>, - visitor: &mut V, + visitor: &V, ) { for element in tuple.tuple(db).all_elements() { visitor.visit_type(db, *element); @@ -245,7 +245,7 @@ impl<'db> TupleType<'db> { pub(crate) fn normalized_impl( self, db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, + visitor: &TypeTransformer<'db>, ) -> Option { TupleType::new(db, self.tuple(db).normalized_impl(db, visitor)) } @@ -393,7 +393,7 @@ impl<'db> FixedLengthTuple> { } #[must_use] - fn normalized_impl(&self, db: &'db dyn Db, visitor: &mut TypeTransformer<'db>) -> Self { + fn normalized_impl(&self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { Self::from_elements(self.0.iter().map(|ty| ty.normalized_impl(db, visitor))) } @@ -707,11 +707,7 @@ impl<'db> VariableLengthTuple> { } #[must_use] - fn normalized_impl( - &self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> TupleSpec<'db> { + fn normalized_impl(&self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> TupleSpec<'db> { let prefix = self .prenormalized_prefix_elements(db, None) .map(|ty| ty.normalized_impl(db, visitor)) @@ -1057,11 +1053,7 @@ impl<'db> Tuple> { } } - pub(crate) fn normalized_impl( - &self, - db: &'db dyn Db, - visitor: &mut TypeTransformer<'db>, - ) -> Self { + pub(crate) fn normalized_impl(&self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self { match self { Tuple::Fixed(tuple) => Tuple::Fixed(tuple.normalized_impl(db, visitor)), Tuple::Variable(tuple) => tuple.normalized_impl(db, visitor), @@ -1121,7 +1113,7 @@ impl<'db> Tuple> { &self, db: &'db dyn Db, other: &Self, - visitor: &mut PairVisitor<'db>, + visitor: &PairVisitor<'db>, ) -> bool { // Two tuples with an incompatible number of required elements must always be disjoint. let (self_min, self_max) = self.len().size_hint(); @@ -1139,7 +1131,7 @@ impl<'db> Tuple> { db: &'db dyn Db, a: impl IntoIterator>, b: impl IntoIterator>, - visitor: &mut PairVisitor<'db>, + visitor: &PairVisitor<'db>, ) -> bool where 'db: 's, diff --git a/crates/ty_python_semantic/src/types/visitor.rs b/crates/ty_python_semantic/src/types/visitor.rs index 705ecd03d0..2bfd3b92bd 100644 --- a/crates/ty_python_semantic/src/types/visitor.rs +++ b/crates/ty_python_semantic/src/types/visitor.rs @@ -15,6 +15,7 @@ use crate::{ walk_type_var_type, walk_typed_dict_type, walk_typeis_type, walk_union, }, }; +use std::cell::{Cell, RefCell}; /// A visitor trait that recurses into nested types. /// @@ -22,97 +23,77 @@ use crate::{ /// but it makes it easy for implementors of the trait to do so. /// See [`any_over_type`] for an example of how to do this. pub(crate) trait TypeVisitor<'db> { - fn visit_type(&mut self, db: &'db dyn Db, ty: Type<'db>); + fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>); - fn visit_union_type(&mut self, db: &'db dyn Db, union: UnionType<'db>) { + fn visit_union_type(&self, db: &'db dyn Db, union: UnionType<'db>) { walk_union(db, union, self); } - fn visit_intersection_type(&mut self, db: &'db dyn Db, intersection: IntersectionType<'db>) { + fn visit_intersection_type(&self, db: &'db dyn Db, intersection: IntersectionType<'db>) { walk_intersection_type(db, intersection, self); } - fn visit_callable_type(&mut self, db: &'db dyn Db, callable: CallableType<'db>) { + fn visit_callable_type(&self, db: &'db dyn Db, callable: CallableType<'db>) { walk_callable_type(db, callable, self); } - fn visit_property_instance_type( - &mut self, - db: &'db dyn Db, - property: PropertyInstanceType<'db>, - ) { + fn visit_property_instance_type(&self, db: &'db dyn Db, property: PropertyInstanceType<'db>) { walk_property_instance_type(db, property, self); } - fn visit_typeis_type(&mut self, db: &'db dyn Db, type_is: TypeIsType<'db>) { + fn visit_typeis_type(&self, db: &'db dyn Db, type_is: TypeIsType<'db>) { walk_typeis_type(db, type_is, self); } - fn visit_subclass_of_type(&mut self, db: &'db dyn Db, subclass_of: SubclassOfType<'db>) { + fn visit_subclass_of_type(&self, db: &'db dyn Db, subclass_of: SubclassOfType<'db>) { walk_subclass_of_type(db, subclass_of, self); } - fn visit_generic_alias_type(&mut self, db: &'db dyn Db, alias: GenericAlias<'db>) { + fn visit_generic_alias_type(&self, db: &'db dyn Db, alias: GenericAlias<'db>) { walk_generic_alias(db, alias, self); } - fn visit_function_type(&mut self, db: &'db dyn Db, function: FunctionType<'db>) { + fn visit_function_type(&self, db: &'db dyn Db, function: FunctionType<'db>) { walk_function_type(db, function, self); } - fn visit_bound_method_type(&mut self, db: &'db dyn Db, method: BoundMethodType<'db>) { + fn visit_bound_method_type(&self, db: &'db dyn Db, method: BoundMethodType<'db>) { walk_bound_method_type(db, method, self); } - fn visit_bound_super_type(&mut self, db: &'db dyn Db, bound_super: BoundSuperType<'db>) { + fn visit_bound_super_type(&self, db: &'db dyn Db, bound_super: BoundSuperType<'db>) { walk_bound_super_type(db, bound_super, self); } - fn visit_nominal_instance_type(&mut self, db: &'db dyn Db, nominal: NominalInstanceType<'db>) { + fn visit_nominal_instance_type(&self, db: &'db dyn Db, nominal: NominalInstanceType<'db>) { walk_nominal_instance_type(db, nominal, self); } - fn visit_bound_type_var_type( - &mut self, - db: &'db dyn Db, - bound_typevar: BoundTypeVarInstance<'db>, - ) { + fn visit_bound_type_var_type(&self, db: &'db dyn Db, bound_typevar: BoundTypeVarInstance<'db>) { walk_bound_type_var_type(db, bound_typevar, self); } - fn visit_type_var_type(&mut self, db: &'db dyn Db, typevar: TypeVarInstance<'db>) { + fn visit_type_var_type(&self, db: &'db dyn Db, typevar: TypeVarInstance<'db>) { walk_type_var_type(db, typevar, self); } - fn visit_protocol_instance_type( - &mut self, - db: &'db dyn Db, - protocol: ProtocolInstanceType<'db>, - ) { + fn visit_protocol_instance_type(&self, db: &'db dyn Db, protocol: ProtocolInstanceType<'db>) { walk_protocol_instance_type(db, protocol, self); } - fn visit_method_wrapper_type( - &mut self, - db: &'db dyn Db, - method_wrapper: MethodWrapperKind<'db>, - ) { + fn visit_method_wrapper_type(&self, db: &'db dyn Db, method_wrapper: MethodWrapperKind<'db>) { walk_method_wrapper_type(db, method_wrapper, self); } - fn visit_known_instance_type( - &mut self, - db: &'db dyn Db, - known_instance: KnownInstanceType<'db>, - ) { + fn visit_known_instance_type(&self, db: &'db dyn Db, known_instance: KnownInstanceType<'db>) { walk_known_instance_type(db, known_instance, self); } - fn visit_type_alias_type(&mut self, db: &'db dyn Db, type_alias: TypeAliasType<'db>) { + fn visit_type_alias_type(&self, db: &'db dyn Db, type_alias: TypeAliasType<'db>) { walk_type_alias_type(db, type_alias, self); } - fn visit_typed_dict_type(&mut self, db: &'db dyn Db, typed_dict: TypedDictType<'db>) { + fn visit_typed_dict_type(&self, db: &'db dyn Db, typed_dict: TypedDictType<'db>) { walk_typed_dict_type(db, typed_dict, self); } } @@ -209,7 +190,7 @@ impl<'db> From> for TypeKind<'db> { fn walk_non_atomic_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, non_atomic_type: NonAtomicType<'db>, - visitor: &mut V, + visitor: &V, ) { match non_atomic_type { NonAtomicType::FunctionLiteral(function) => visitor.visit_function_type(db, function), @@ -254,23 +235,25 @@ pub(super) fn any_over_type<'db>( ) -> bool { struct AnyOverTypeVisitor<'db, 'a> { query: &'a dyn Fn(Type<'db>) -> bool, - seen_types: FxIndexSet>, - found_matching_type: bool, + seen_types: RefCell>>, + found_matching_type: Cell, } impl<'db> TypeVisitor<'db> for AnyOverTypeVisitor<'db, '_> { - fn visit_type(&mut self, db: &'db dyn Db, ty: Type<'db>) { - if self.found_matching_type { + fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) { + let already_found = self.found_matching_type.get(); + if already_found { return; } - self.found_matching_type |= (self.query)(ty); - if self.found_matching_type { + let found = already_found | (self.query)(ty); + self.found_matching_type.set(found); + if found { return; } match TypeKind::from(ty) { TypeKind::Atomic => {} TypeKind::NonAtomic(non_atomic_type) => { - if !self.seen_types.insert(non_atomic_type) { + if !self.seen_types.borrow_mut().insert(non_atomic_type) { // If we have already seen this type, we can skip it. return; } @@ -280,11 +263,11 @@ pub(super) fn any_over_type<'db>( } } - let mut visitor = AnyOverTypeVisitor { + let visitor = AnyOverTypeVisitor { query, - seen_types: FxIndexSet::default(), - found_matching_type: false, + seen_types: RefCell::new(FxIndexSet::default()), + found_matching_type: Cell::new(false), }; visitor.visit_type(db, ty); - visitor.found_matching_type + visitor.found_matching_type.get() }