diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 7e62695426..3f29cbb4fa 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -53,14 +53,18 @@ //! //! [bdd]: https://en.wikipedia.org/wiki/Binary_decision_diagram +use std::cell::RefCell; use std::cmp::Ordering; use std::fmt::Display; use itertools::Itertools; use rustc_hash::FxHashSet; -use crate::Db; -use crate::types::{BoundTypeVarInstance, IntersectionType, Type, UnionType}; +use crate::types::visitor::{NonAtomicType, TypeKind, TypeVisitor, walk_non_atomic_type}; +use crate::types::{ + BoundTypeVarInstance, IntersectionType, Type, TypeVarBoundOrConstraints, UnionType, +}; +use crate::{Db, FxIndexSet}; /// An extension trait for building constraint sets from [`Option`] values. pub(crate) trait OptionConstraintsExtension { @@ -1370,3 +1374,67 @@ impl<'db> SatisfiedClauses<'db> { DisplaySatisfiedClauses { clauses: self, db } } } + +/// Returns a constraint set describing the valid specializations of a typevar. +impl<'db> BoundTypeVarInstance<'db> { + pub(crate) fn valid_specializations(self, db: &'db dyn Db) -> ConstraintSet<'db> { + match self.typevar(db).bound_or_constraints(db) { + None => ConstraintSet::from(true), + Some(TypeVarBoundOrConstraints::UpperBound(bound)) => { + ConstraintSet::constrain_typevar(db, self, Type::Never, bound) + } + Some(TypeVarBoundOrConstraints::Constraints(constraints)) => { + constraints.elements(db).iter().when_any(db, |constraint| { + ConstraintSet::constrain_typevar(db, self, *constraint, *constraint) + }) + } + } + } +} + +/// Returns a constraint set describing the valid specializations of any typevar mentioned in a +/// type. +impl<'db> Type<'db> { + pub(crate) fn valid_specializations(self, db: &'db dyn Db) -> ConstraintSet<'db> { + struct ValidSpecializationsVisitor<'db> { + seen_types: RefCell>>, + result: RefCell>, + } + + impl<'db> TypeVisitor<'db> for ValidSpecializationsVisitor<'db> { + fn should_visit_lazy_type_attributes(&self) -> bool { + false + } + + fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) { + match ty { + Type::NonInferableTypeVar(bound_typevar) | Type::TypeVar(bound_typevar) => { + let valid_specializations = bound_typevar.valid_specializations(db); + self.result + .borrow_mut() + .intersect(db, &valid_specializations); + } + _ => {} + } + + match TypeKind::from(ty) { + TypeKind::Atomic => {} + TypeKind::NonAtomic(non_atomic_type) => { + if !self.seen_types.borrow_mut().insert(non_atomic_type) { + // If we have already seen this type, we can skip it. + return; + } + walk_non_atomic_type(db, non_atomic_type, self); + } + } + } + } + + let visitor = ValidSpecializationsVisitor { + seen_types: RefCell::new(FxIndexSet::default()), + result: RefCell::new(ConstraintSet::from(true)), + }; + visitor.visit_type(db, self); + visitor.result.into_inner() + } +} diff --git a/crates/ty_python_semantic/src/types/visitor.rs b/crates/ty_python_semantic/src/types/visitor.rs index 35d9d5f1a1..aae08c5d9a 100644 --- a/crates/ty_python_semantic/src/types/visitor.rs +++ b/crates/ty_python_semantic/src/types/visitor.rs @@ -107,7 +107,7 @@ pub(crate) trait TypeVisitor<'db> { /// Enumeration of types that may contain other types, such as unions, intersections, and generics. #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -enum NonAtomicType<'db> { +pub(crate) enum NonAtomicType<'db> { Union(UnionType<'db>), Intersection(IntersectionType<'db>), FunctionLiteral(FunctionType<'db>), @@ -128,7 +128,7 @@ enum NonAtomicType<'db> { TypeAlias(TypeAliasType<'db>), } -enum TypeKind<'db> { +pub(crate) enum TypeKind<'db> { Atomic, NonAtomic(NonAtomicType<'db>), } @@ -200,7 +200,7 @@ impl<'db> From> for TypeKind<'db> { } } -fn walk_non_atomic_type<'db, V: TypeVisitor<'db> + ?Sized>( +pub(crate) fn walk_non_atomic_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, non_atomic_type: NonAtomicType<'db>, visitor: &V,