diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md index 03e281d03c..b078d38d0e 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md @@ -150,6 +150,17 @@ reveal_type(Constrained[int | str]()) # revealed: Constrained[int | str] reveal_type(Constrained[object]()) # revealed: Unknown ``` +If the type variable has a default, it can be omitted: + +```py +WithDefaultU = TypeVar("WithDefaultU", default=int) + +class WithDefault(Generic[T, WithDefaultU]): ... + +reveal_type(WithDefault[str, str]()) # revealed: WithDefault[str, str] +reveal_type(WithDefault[str]()) # revealed: WithDefault[str, int] +``` + ## Inferring generic class parameters We can infer the type parameter from a type context: diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md index 736687d420..39bb9cfb62 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md @@ -133,6 +133,15 @@ reveal_type(Constrained[int | str]()) # revealed: Constrained[int | str] reveal_type(Constrained[object]()) # revealed: Unknown ``` +If the type variable has a default, it can be omitted: + +```py +class WithDefault[T, U = int]: ... + +reveal_type(WithDefault[str, str]()) # revealed: WithDefault[str, str] +reveal_type(WithDefault[str]()) # revealed: WithDefault[str, int] +``` + ## Inferring generic class parameters We can infer the type parameter from a type context: diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index d6068e8f55..3f93985a43 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -1316,10 +1316,27 @@ impl<'db> Binding<'db> { self.inherited_specialization } + /// Returns the bound types for each parameter, in parameter source order, or `None` if no + /// argument was matched to that parameter. pub(crate) fn parameter_types(&self) -> &[Option>] { &self.parameter_tys } + /// Returns the bound types for each parameter, in parameter source order, with default values + /// applied for arguments that weren't matched to a parameter. Returns `None` if there are any + /// non-default arguments that weren't matched to a parameter. + pub(crate) fn parameter_types_with_defaults( + &self, + signature: &Signature<'db>, + ) -> Option]>> { + signature + .parameters() + .iter() + .zip(&self.parameter_tys) + .map(|(parameter, parameter_ty)| parameter_ty.or(parameter.default_type())) + .collect() + } + pub(crate) fn arguments_for_parameter<'a>( &'a self, argument_types: &'a CallArgumentTypes<'a, 'db>, diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 0210cd2062..a436f880af 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -123,6 +123,9 @@ impl<'db> GenericContext<'db> { } None => {} } + if let Some(default_ty) = typevar.default_ty(db) { + parameter = parameter.with_default_type(default_ty); + } parameter } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index cc4745b768..87d9a44fb7 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -6711,10 +6711,8 @@ impl<'db> TypeInferenceBuilder<'db> { } _ => CallArgumentTypes::positional([self.infer_type_expression(slice_node)]), }; - let signatures = Signatures::single(CallableSignature::single( - value_ty, - generic_context.signature(self.db()), - )); + let signature = generic_context.signature(self.db()); + let signatures = Signatures::single(CallableSignature::single(value_ty, signature.clone())); let bindings = match Bindings::match_parameters(signatures, &call_argument_types) .check_types(self.db(), &call_argument_types) { @@ -6732,14 +6730,10 @@ impl<'db> TypeInferenceBuilder<'db> { .matching_overloads() .next() .expect("valid bindings should have matching overload"); - let specialization = generic_context.specialize( - self.db(), - overload - .parameter_types() - .iter() - .map(|ty| ty.unwrap_or(Type::unknown())) - .collect(), - ); + let parameters = overload + .parameter_types_with_defaults(&signature) + .expect("matching overload should not have missing arguments"); + let specialization = generic_context.specialize(self.db(), parameters); Type::from(GenericAlias::new(self.db(), generic_class, specialization)) }