diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index 62852bc0ba..a8e064c3b5 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -234,3 +234,48 @@ reveal_type(x) # revealed: Foo x: int = 1 reveal_type(x) # revealed: Literal[1] ``` + +## Annotations influence generic call inference + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Literal + +def f[T](x: T) -> list[T]: + return [x] + +a = f("a") +reveal_type(a) # revealed: list[Literal["a"]] + +b: list[int | Literal["a"]] = f("a") +reveal_type(b) # revealed: list[int | Literal["a"]] + +c: list[int | str] = f("a") +reveal_type(c) # revealed: list[int | str] + +d: list[int | tuple[int, int]] = f((1, 2)) +reveal_type(d) # revealed: list[int | tuple[int, int]] + +e: list[int] = f(True) +reveal_type(e) # revealed: list[int] + +# TODO: the RHS should be inferred as `list[Literal["a"]]` here +# error: [invalid-assignment] "Object of type `list[int | Literal["a"]]` is not assignable to `list[int]`" +g: list[int] = f("a") + +# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `tuple[int]`" +h: tuple[int] = f("a") + +def f2[T: int](x: T) -> T: + return x + +i: int = f2(True) +reveal_type(i) # revealed: int + +j: int | str = f2(True) +reveal_type(j) # revealed: Literal[True] +``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index bf68255901..57b403b307 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4805,7 +4805,7 @@ impl<'db> Type<'db> { ) -> Result, CallError<'db>> { self.bindings(db) .match_parameters(db, argument_types) - .check_types(db, argument_types) + .check_types(db, argument_types, &TypeContext::default()) } /// Look up a dunder method on the meta-type of `self` and call it. @@ -4854,7 +4854,7 @@ impl<'db> Type<'db> { let bindings = dunder_callable .bindings(db) .match_parameters(db, argument_types) - .check_types(db, argument_types)?; + .check_types(db, argument_types, &TypeContext::default())?; if boundness == Boundness::PossiblyUnbound { return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index c1da805d48..e19e1c6c04 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -32,8 +32,8 @@ use crate::types::tuple::{TupleLength, TupleType}; use crate::types::{ BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType, - TrackedConstraintSet, TypeAliasType, TypeMapping, UnionType, WrapperDescriptorKind, enums, - ide_support, todo_type, + TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType, + WrapperDescriptorKind, enums, ide_support, todo_type, }; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; use ruff_python_ast::{self as ast, PythonVersion}; @@ -122,6 +122,9 @@ impl<'db> Bindings<'db> { /// You must provide an `argument_types` that was created from the same `arguments` that you /// provided to [`match_parameters`][Self::match_parameters]. /// + /// The type context of the call expression is also used to infer the specialization of generic + /// calls. + /// /// We update the bindings to include the return type of the call, the bound types for all /// parameters, and any errors resulting from binding the call, all for each union element and /// overload (if any). @@ -129,9 +132,12 @@ impl<'db> Bindings<'db> { mut self, db: &'db dyn Db, argument_types: &CallArguments<'_, 'db>, + call_expression_tcx: &TypeContext<'db>, ) -> Result> { for element in &mut self.elements { - if let Some(mut updated_argument_forms) = element.check_types(db, argument_types) { + if let Some(mut updated_argument_forms) = + element.check_types(db, argument_types, call_expression_tcx) + { // If this element returned a new set of argument forms (indicating successful // argument type expansion), update the `Bindings` with these forms. updated_argument_forms.shrink_to_fit(); @@ -1281,6 +1287,7 @@ impl<'db> CallableBinding<'db> { &mut self, db: &'db dyn Db, argument_types: &CallArguments<'_, 'db>, + call_expression_tcx: &TypeContext<'db>, ) -> Option { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. @@ -1293,7 +1300,7 @@ impl<'db> CallableBinding<'db> { // still perform type checking for non-overloaded function to provide better user // experience. if let [overload] = self.overloads.as_mut_slice() { - overload.check_types(db, argument_types.as_ref()); + overload.check_types(db, argument_types.as_ref(), call_expression_tcx); } return None; } @@ -1301,7 +1308,7 @@ impl<'db> CallableBinding<'db> { // If only one candidate overload remains, it is the winning match. Evaluate it as // a regular (non-overloaded) call. self.matching_overload_index = Some(index); - self.overloads[index].check_types(db, argument_types.as_ref()); + self.overloads[index].check_types(db, argument_types.as_ref(), call_expression_tcx); return None; } MatchingOverloadIndex::Multiple(indexes) => { @@ -1313,7 +1320,7 @@ impl<'db> CallableBinding<'db> { // Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine // whether it is compatible with the supplied argument list. for (_, overload) in self.matching_overloads_mut() { - overload.check_types(db, argument_types.as_ref()); + overload.check_types(db, argument_types.as_ref(), call_expression_tcx); } match self.matching_overload_index() { @@ -1430,7 +1437,7 @@ impl<'db> CallableBinding<'db> { merged_argument_forms.merge(&argument_forms); for (_, overload) in self.matching_overloads_mut() { - overload.check_types(db, expanded_arguments); + overload.check_types(db, expanded_arguments, call_expression_tcx); } let return_type = match self.matching_overload_index() { @@ -2243,6 +2250,7 @@ struct ArgumentTypeChecker<'a, 'db> { arguments: &'a CallArguments<'a, 'db>, argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], + call_expression_tcx: &'a TypeContext<'db>, errors: &'a mut Vec>, specialization: Option>, @@ -2256,6 +2264,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { arguments: &'a CallArguments<'a, 'db>, argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], + call_expression_tcx: &'a TypeContext<'db>, errors: &'a mut Vec>, ) -> Self { Self { @@ -2264,6 +2273,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { arguments, argument_matches, parameter_tys, + call_expression_tcx, errors, specialization: None, inherited_specialization: None, @@ -2304,8 +2314,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { return; } - let parameters = self.signature.parameters(); let mut builder = SpecializationBuilder::new(self.db); + + // Note that we infer the annotated type _before_ the arguments if this call is part of + // an annotated assignment, to closer match the order of any unions written in the type + // annotation. + if let Some(return_ty) = self.signature.return_ty + && let Some(call_expression_tcx) = self.call_expression_tcx.annotation + { + // Ignore any specialization errors here, because the type context is only used to + // optionally widen the return type. + let _ = builder.infer(return_ty, call_expression_tcx); + } + + let parameters = self.signature.parameters(); for (argument_index, adjusted_argument_index, _, argument_type) in self.enumerate_argument_types() { @@ -2316,6 +2338,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { let Some(expected_type) = parameter.annotated_type() else { continue; }; + if let Err(error) = builder.infer( expected_type, variadic_argument_type.unwrap_or(argument_type), @@ -2327,6 +2350,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { } } } + self.specialization = self.signature.generic_context.map(|gc| builder.build(gc)); self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| { // The inherited generic context is used when inferring the specialization of a generic @@ -2688,13 +2712,19 @@ impl<'db> Binding<'db> { self.argument_matches = matcher.finish(); } - fn check_types(&mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>) { + fn check_types( + &mut self, + db: &'db dyn Db, + arguments: &CallArguments<'_, 'db>, + call_expression_tcx: &TypeContext<'db>, + ) { let mut checker = ArgumentTypeChecker::new( db, &self.signature, arguments, &self.argument_matches, &mut self.parameter_tys, + call_expression_tcx, &mut self.errors, ); diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 7f81248a22..cbeef15ab9 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -352,7 +352,7 @@ struct ExpressionWithContext<'db> { /// more precise inference results, aka "bidirectional type inference". #[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)] pub(crate) struct TypeContext<'db> { - annotation: Option>, + pub(crate) annotation: Option>, } impl<'db> TypeContext<'db> { diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index ad774d38ea..c5cf398c24 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -5775,7 +5775,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { fn infer_call_expression( &mut self, call_expression: &ast::ExprCall, - _tcx: TypeContext<'db>, + tcx: TypeContext<'db>, ) -> Type<'db> { let ast::ExprCall { range: _, @@ -5955,7 +5955,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let mut bindings = match bindings.check_types(self.db(), &call_arguments) { + let mut bindings = match bindings.check_types(self.db(), &call_arguments, &tcx) { Ok(bindings) => bindings, Err(CallError(_, bindings)) => { bindings.report_diagnostics(&self.context, call_expression.into()); @@ -8521,7 +8521,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let binding = Binding::single(value_ty, generic_context.signature(self.db())); let bindings = match Bindings::from(binding) .match_parameters(self.db(), &call_argument_types) - .check_types(self.db(), &call_argument_types) + .check_types(self.db(), &call_argument_types, &TypeContext::default()) { Ok(bindings) => bindings, Err(CallError(_, bindings)) => {