diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index c2fb6cbc1b..3175ed7216 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -366,6 +366,31 @@ reveal_type(f(g("a"))) # revealed: tuple[Literal["a"] | None, int] reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None ``` +## Passing generic functions to generic functions + +```py +from typing import Callable, TypeVar + +A = TypeVar("A") +B = TypeVar("B") +T = TypeVar("T") + +def invoke(fn: Callable[[A], B], value: A) -> B: + return fn(value) + +def identity(x: T) -> T: + return x + +def head(xs: list[T]) -> T: + return xs[0] + +# TODO: this should be `Literal[1]` +reveal_type(invoke(identity, 1)) # revealed: Unknown + +# TODO: this should be `Unknown | int` +reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown +``` + ## Opaque decorators don't affect typevar binding Inside the body of a generic function, we should be able to see that the typevars bound by that diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 6dda932d4f..a9224d46c8 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -323,6 +323,27 @@ reveal_type(f(g("a"))) # revealed: tuple[Literal["a"] | None, int] reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None ``` +## Passing generic functions to generic functions + +```py +from typing import Callable + +def invoke[A, B](fn: Callable[[A], B], value: A) -> B: + return fn(value) + +def identity[T](x: T) -> T: + return x + +def head[T](xs: list[T]) -> T: + return xs[0] + +# TODO: this should be `Literal[1]` +reveal_type(invoke(identity, 1)) # revealed: Unknown + +# TODO: this should be `Unknown | int` +reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown +``` + ## Protocols as TypeVar bounds Protocol types can be used as TypeVar bounds, just like nominal types. diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md b/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md index d28d261fb1..486c354c24 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md @@ -321,8 +321,11 @@ a covariant generic, this is equivalent to using the upper bound of the type par `object`): ```py +from typing import Self + class Covariant[T]: - def get(self) -> T: + # TODO: remove the explicit `Self` annotation, once we support the implicit type of `self` + def get(self: Self) -> T: raise NotImplementedError def _(x: object): @@ -335,7 +338,8 @@ Similarly, contravariant type parameters use their lower bound of `Never`: ```py class Contravariant[T]: - def push(self, x: T) -> None: ... + # TODO: remove the explicit `Self` annotation, once we support the implicit type of `self` + def push(self: Self, x: T) -> None: ... def _(x: object): if isinstance(x, Contravariant): @@ -350,8 +354,10 @@ the type system, so we represent it with the internal `Top[]` special form. ```py class Invariant[T]: - def push(self, x: T) -> None: ... - def get(self) -> T: + # TODO: remove the explicit `Self` annotation, once we support the implicit type of `self` + def push(self: Self, x: T) -> None: ... + # TODO: remove the explicit `Self` annotation, once we support the implicit type of `self` + def get(self: Self) -> T: raise NotImplementedError def _(x: object): diff --git a/crates/ty_python_semantic/resources/mdtest/overloads.md b/crates/ty_python_semantic/resources/mdtest/overloads.md index 08bd0c3021..8e8e9e524b 100644 --- a/crates/ty_python_semantic/resources/mdtest/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/overloads.md @@ -99,7 +99,7 @@ reveal_type(foo(b"")) # revealed: bytes ## Methods ```py -from typing import overload +from typing_extensions import Self, overload class Foo1: @overload @@ -126,6 +126,18 @@ foo2 = Foo2() reveal_type(foo2.method) # revealed: Overload[() -> None, (x: str) -> str] reveal_type(foo2.method()) # revealed: None reveal_type(foo2.method("")) # revealed: str + +class Foo3: + @overload + def takes_self_or_int(self: Self, x: Self) -> Self: ... + @overload + def takes_self_or_int(self: Self, x: int) -> int: ... + def takes_self_or_int(self: Self, x: Self | int) -> Self | int: + return x + +foo3 = Foo3() +reveal_type(foo3.takes_self_or_int(foo3)) # revealed: Foo3 +reveal_type(foo3.takes_self_or_int(1)) # revealed: int ``` ## Constructor diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 8e90c205e1..4fb4385d1e 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1593,9 +1593,16 @@ impl<'db> Type<'db> { }) } + (Type::TypeVar(_), _) if relation.is_assignability() => { + // The implicit lower bound of a typevar is `Never`, which means + // that it is always assignable to any other type. + + // TODO: record the unification constraints + + ConstraintSet::from(true) + } + // `Never` is the bottom type, the empty set. - // Other than one unlikely edge case (TypeVars bound to `Never`), - // no other type is a subtype of or assignable to `Never`. (_, Type::Never) => ConstraintSet::from(false), (Type::Union(union), _) => union.elements(db).iter().when_all(db, |&elem_ty| { @@ -1632,6 +1639,22 @@ impl<'db> Type<'db> { // be specialized to `Never`.) (_, Type::NonInferableTypeVar(_)) => ConstraintSet::from(false), + (_, Type::TypeVar(typevar)) + if relation.is_assignability() + && typevar.typevar(db).upper_bound(db).is_none_or(|bound| { + !self + .has_relation_to_impl(db, bound, relation, visitor) + .is_never_satisfied() + }) => + { + // TODO: record the unification constraints + + typevar + .typevar(db) + .upper_bound(db) + .when_none_or(|bound| self.has_relation_to_impl(db, bound, relation, visitor)) + } + // TODO: Infer specializations here (Type::TypeVar(_), _) | (_, Type::TypeVar(_)) => ConstraintSet::from(false), @@ -5662,13 +5685,25 @@ impl<'db> Type<'db> { ], }); }; - let instance = Type::instance(db, class.unknown_specialization(db)); + + let upper_bound = Type::instance( + db, + class.apply_specialization(db, |generic_context| { + let types = generic_context + .variables(db) + .iter() + .map(|typevar| Type::NonInferableTypeVar(*typevar)); + + generic_context.specialize(db, types.collect()) + }), + ); + let class_definition = class.definition(db); let typevar = TypeVarInstance::new( db, ast::name::Name::new_static("Self"), Some(class_definition), - Some(TypeVarBoundOrConstraints::UpperBound(instance).into()), + Some(TypeVarBoundOrConstraints::UpperBound(upper_bound).into()), // According to the [spec], we can consider `Self` // equivalent to an invariant type variable // [spec]: https://typing.python.org/en/latest/spec/generics.html#self @@ -6010,8 +6045,8 @@ impl<'db> Type<'db> { partial.get(db, bound_typevar).unwrap_or(self) } TypeMapping::MarkTypeVarsInferable(binding_context) => { - if bound_typevar.binding_context(db) == *binding_context { - Type::TypeVar(bound_typevar) + if binding_context.is_none_or(|context| context == bound_typevar.binding_context(db)) { + Type::TypeVar(bound_typevar.mark_typevars_inferable(db, visitor)) } else { self } @@ -6695,8 +6730,17 @@ pub enum TypeMapping<'a, 'db> { BindSelf(Type<'db>), /// Replaces occurrences of `typing.Self` with a new `Self` type variable with the given upper bound. ReplaceSelf { new_upper_bound: Type<'db> }, - /// Marks the typevars that are bound by a generic class or function as inferable. - MarkTypeVarsInferable(BindingContext<'db>), + /// Marks type variables as inferable. + /// + /// When we create the signature for a generic function, we mark its type variables as inferable. Since + /// the generic function might reference type variables from enclosing generic scopes, we include the + /// function's binding context in order to only mark those type variables as inferable that are actually + /// bound by that function. + /// + /// When the parameter is set to `None`, *all* type variables will be marked as inferable. We use this + /// variant when descending into the bounds and/or constraints, and the default value of a type variable, + /// which may include nested type variables (`Self` has a bound of `C[T]` for a generic class `C[T]`). + MarkTypeVarsInferable(Option>), /// Create the top or bottom materialization of a type. Materialize(MaterializationKind), } @@ -7637,6 +7681,43 @@ impl<'db> TypeVarInstance<'db> { ) } + fn mark_typevars_inferable( + self, + db: &'db dyn Db, + visitor: &ApplyTypeMappingVisitor<'db>, + ) -> Self { + // Type variables can have nested type variables in their bounds, constraints, or default value. + // When we mark a type variable as inferable, we also mark all of these nested type variables as + // inferable, so we set the parameter to `None` here. + let type_mapping = &TypeMapping::MarkTypeVarsInferable(None); + + Self::new( + db, + self.name(db), + self.definition(db), + self._bound_or_constraints(db) + .map(|bound_or_constraints| match bound_or_constraints { + TypeVarBoundOrConstraintsEvaluation::Eager(bound_or_constraints) => { + bound_or_constraints + .mark_typevars_inferable(db, visitor) + .into() + } + TypeVarBoundOrConstraintsEvaluation::LazyUpperBound + | TypeVarBoundOrConstraintsEvaluation::LazyConstraints => bound_or_constraints, + }), + self.explicit_variance(db), + self._default(db).and_then(|default| match default { + TypeVarDefaultEvaluation::Eager(ty) => { + Some(ty.apply_type_mapping_impl(db, type_mapping, visitor).into()) + } + TypeVarDefaultEvaluation::Lazy => self + .lazy_default(db) + .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor).into()), + }), + self.kind(db), + ) + } + fn to_instance(self, db: &'db dyn Db) -> Option { let bound_or_constraints = match self.bound_or_constraints(db)? { TypeVarBoundOrConstraints::UpperBound(upper_bound) => { @@ -7867,6 +7948,18 @@ impl<'db> BoundTypeVarInstance<'db> { ) } + fn mark_typevars_inferable( + self, + db: &'db dyn Db, + visitor: &ApplyTypeMappingVisitor<'db>, + ) -> Self { + Self::new( + db, + self.typevar(db).mark_typevars_inferable(db, visitor), + self.binding_context(db), + ) + } + fn to_instance(self, db: &'db dyn Db) -> Option { Some(Self::new( db, @@ -7972,6 +8065,31 @@ impl<'db> TypeVarBoundOrConstraints<'db> { } } } + + fn mark_typevars_inferable( + self, + db: &'db dyn Db, + visitor: &ApplyTypeMappingVisitor<'db>, + ) -> Self { + let type_mapping = &TypeMapping::MarkTypeVarsInferable(None); + + match self { + TypeVarBoundOrConstraints::UpperBound(bound) => TypeVarBoundOrConstraints::UpperBound( + bound.apply_type_mapping_impl(db, type_mapping, visitor), + ), + TypeVarBoundOrConstraints::Constraints(constraints) => { + TypeVarBoundOrConstraints::Constraints(UnionType::new( + db, + constraints + .elements(db) + .iter() + .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)) + .collect::>() + .into_boxed_slice(), + )) + } + } + } } /// Error returned if a type is not awaitable. diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 1677a7ea6f..7650a4eac3 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -491,6 +491,18 @@ fn is_subtype_in_invariant_position<'db>( let base_bottom = base_type.bottom_materialization(db); let is_subtype_of = |derived: Type<'db>, base: Type<'db>| { + // TODO: + // This should be removed and properly handled in the respective + // `(Type::TypeVar(_), _) | (_, Type::TypeVar(_))` branch of + // `Type::has_relation_to_impl`. Right now, we can not generally + // return `ConstraintSet::from(true)` from that branch, as that + // leads to union simplification, which means that we lose track + // of type variables without recording the constraints under which + // the relation holds. + if matches!(base, Type::TypeVar(_)) || matches!(derived, Type::TypeVar(_)) { + return ConstraintSet::from(true); + } + derived.has_relation_to_impl(db, base, TypeRelation::Subtyping, visitor) }; match (derived_materialization, base_materialization) { diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 35da85a09f..5386293cfb 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -367,7 +367,9 @@ impl<'db> Signature<'db> { let plain_return_ty = definition_expression_type(db, definition, returns.as_ref()) .apply_type_mapping( db, - &TypeMapping::MarkTypeVarsInferable(BindingContext::Definition(definition)), + &TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition( + definition, + ))), ); if function_node.is_async && !is_generator { KnownClass::CoroutineType @@ -1549,7 +1551,9 @@ impl<'db> Parameter<'db> { annotated_type: parameter.annotation().map(|annotation| { definition_expression_type(db, definition, annotation).apply_type_mapping( db, - &TypeMapping::MarkTypeVarsInferable(BindingContext::Definition(definition)), + &TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition( + definition, + ))), ) }), kind,