use type context for inference of generic function calls

This commit is contained in:
Ibraheem Ahmed 2025-09-18 16:50:41 -04:00
parent 44fc87f491
commit 5f294f9f2e
5 changed files with 90 additions and 15 deletions

View file

@ -234,3 +234,48 @@ reveal_type(x) # revealed: Foo
x: int = 1 x: int = 1
reveal_type(x) # revealed: Literal[1] reveal_type(x) # revealed: Literal[1]
``` ```
## Annotations influence generic call inference
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Literal
def f[T](x: T) -> list[T]:
return [x]
a = f("a")
reveal_type(a) # revealed: list[Literal["a"]]
b: list[int | Literal["a"]] = f("a")
reveal_type(b) # revealed: list[int | Literal["a"]]
c: list[int | str] = f("a")
reveal_type(c) # revealed: list[int | str]
d: list[int | tuple[int, int]] = f((1, 2))
reveal_type(d) # revealed: list[int | tuple[int, int]]
e: list[int] = f(True)
reveal_type(e) # revealed: list[int]
# TODO: the RHS should be inferred as `list[Literal["a"]]` here
# error: [invalid-assignment] "Object of type `list[int | Literal["a"]]` is not assignable to `list[int]`"
g: list[int] = f("a")
# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `tuple[int]`"
h: tuple[int] = f("a")
def f2[T: int](x: T) -> T:
return x
i: int = f2(True)
reveal_type(i) # revealed: int
j: int | str = f2(True)
reveal_type(j) # revealed: Literal[True]
```

View file

@ -4805,7 +4805,7 @@ impl<'db> Type<'db> {
) -> Result<Bindings<'db>, CallError<'db>> { ) -> Result<Bindings<'db>, CallError<'db>> {
self.bindings(db) self.bindings(db)
.match_parameters(db, argument_types) .match_parameters(db, argument_types)
.check_types(db, argument_types) .check_types(db, argument_types, &TypeContext::default())
} }
/// Look up a dunder method on the meta-type of `self` and call it. /// Look up a dunder method on the meta-type of `self` and call it.
@ -4854,7 +4854,7 @@ impl<'db> Type<'db> {
let bindings = dunder_callable let bindings = dunder_callable
.bindings(db) .bindings(db)
.match_parameters(db, argument_types) .match_parameters(db, argument_types)
.check_types(db, argument_types)?; .check_types(db, argument_types, &TypeContext::default())?;
if boundness == Boundness::PossiblyUnbound { if boundness == Boundness::PossiblyUnbound {
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
} }

View file

@ -32,8 +32,8 @@ use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{ use crate::types::{
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType, BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType, KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
TrackedConstraintSet, TypeAliasType, TypeMapping, UnionType, WrapperDescriptorKind, enums, TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType,
ide_support, todo_type, WrapperDescriptorKind, enums, ide_support, todo_type,
}; };
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, PythonVersion}; use ruff_python_ast::{self as ast, PythonVersion};
@ -122,6 +122,9 @@ impl<'db> Bindings<'db> {
/// You must provide an `argument_types` that was created from the same `arguments` that you /// You must provide an `argument_types` that was created from the same `arguments` that you
/// provided to [`match_parameters`][Self::match_parameters]. /// provided to [`match_parameters`][Self::match_parameters].
/// ///
/// The type context of the call expression is also used to infer the specialization of generic
/// calls.
///
/// We update the bindings to include the return type of the call, the bound types for all /// We update the bindings to include the return type of the call, the bound types for all
/// parameters, and any errors resulting from binding the call, all for each union element and /// parameters, and any errors resulting from binding the call, all for each union element and
/// overload (if any). /// overload (if any).
@ -129,9 +132,12 @@ impl<'db> Bindings<'db> {
mut self, mut self,
db: &'db dyn Db, db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>, argument_types: &CallArguments<'_, 'db>,
call_expression_tcx: &TypeContext<'db>,
) -> Result<Self, CallError<'db>> { ) -> Result<Self, CallError<'db>> {
for element in &mut self.elements { for element in &mut self.elements {
if let Some(mut updated_argument_forms) = element.check_types(db, argument_types) { if let Some(mut updated_argument_forms) =
element.check_types(db, argument_types, call_expression_tcx)
{
// If this element returned a new set of argument forms (indicating successful // If this element returned a new set of argument forms (indicating successful
// argument type expansion), update the `Bindings` with these forms. // argument type expansion), update the `Bindings` with these forms.
updated_argument_forms.shrink_to_fit(); updated_argument_forms.shrink_to_fit();
@ -1281,6 +1287,7 @@ impl<'db> CallableBinding<'db> {
&mut self, &mut self,
db: &'db dyn Db, db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>, argument_types: &CallArguments<'_, 'db>,
call_expression_tcx: &TypeContext<'db>,
) -> Option<ArgumentForms> { ) -> Option<ArgumentForms> {
// If this callable is a bound method, prepend the self instance onto the arguments list // If this callable is a bound method, prepend the self instance onto the arguments list
// before checking. // before checking.
@ -1293,7 +1300,7 @@ impl<'db> CallableBinding<'db> {
// still perform type checking for non-overloaded function to provide better user // still perform type checking for non-overloaded function to provide better user
// experience. // experience.
if let [overload] = self.overloads.as_mut_slice() { if let [overload] = self.overloads.as_mut_slice() {
overload.check_types(db, argument_types.as_ref()); overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
} }
return None; return None;
} }
@ -1301,7 +1308,7 @@ impl<'db> CallableBinding<'db> {
// If only one candidate overload remains, it is the winning match. Evaluate it as // If only one candidate overload remains, it is the winning match. Evaluate it as
// a regular (non-overloaded) call. // a regular (non-overloaded) call.
self.matching_overload_index = Some(index); self.matching_overload_index = Some(index);
self.overloads[index].check_types(db, argument_types.as_ref()); self.overloads[index].check_types(db, argument_types.as_ref(), call_expression_tcx);
return None; return None;
} }
MatchingOverloadIndex::Multiple(indexes) => { MatchingOverloadIndex::Multiple(indexes) => {
@ -1313,7 +1320,7 @@ impl<'db> CallableBinding<'db> {
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine // Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
// whether it is compatible with the supplied argument list. // whether it is compatible with the supplied argument list.
for (_, overload) in self.matching_overloads_mut() { for (_, overload) in self.matching_overloads_mut() {
overload.check_types(db, argument_types.as_ref()); overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
} }
match self.matching_overload_index() { match self.matching_overload_index() {
@ -1430,7 +1437,7 @@ impl<'db> CallableBinding<'db> {
merged_argument_forms.merge(&argument_forms); merged_argument_forms.merge(&argument_forms);
for (_, overload) in self.matching_overloads_mut() { for (_, overload) in self.matching_overloads_mut() {
overload.check_types(db, expanded_arguments); overload.check_types(db, expanded_arguments, call_expression_tcx);
} }
let return_type = match self.matching_overload_index() { let return_type = match self.matching_overload_index() {
@ -2243,6 +2250,7 @@ struct ArgumentTypeChecker<'a, 'db> {
arguments: &'a CallArguments<'a, 'db>, arguments: &'a CallArguments<'a, 'db>,
argument_matches: &'a [MatchedArgument<'db>], argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>], parameter_tys: &'a mut [Option<Type<'db>>],
call_expression_tcx: &'a TypeContext<'db>,
errors: &'a mut Vec<BindingError<'db>>, errors: &'a mut Vec<BindingError<'db>>,
specialization: Option<Specialization<'db>>, specialization: Option<Specialization<'db>>,
@ -2256,6 +2264,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
arguments: &'a CallArguments<'a, 'db>, arguments: &'a CallArguments<'a, 'db>,
argument_matches: &'a [MatchedArgument<'db>], argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>], parameter_tys: &'a mut [Option<Type<'db>>],
call_expression_tcx: &'a TypeContext<'db>,
errors: &'a mut Vec<BindingError<'db>>, errors: &'a mut Vec<BindingError<'db>>,
) -> Self { ) -> Self {
Self { Self {
@ -2264,6 +2273,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
arguments, arguments,
argument_matches, argument_matches,
parameter_tys, parameter_tys,
call_expression_tcx,
errors, errors,
specialization: None, specialization: None,
inherited_specialization: None, inherited_specialization: None,
@ -2304,8 +2314,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
return; return;
} }
let parameters = self.signature.parameters();
let mut builder = SpecializationBuilder::new(self.db); let mut builder = SpecializationBuilder::new(self.db);
// Note that we infer the annotated type _before_ the arguments if this call is part of
// an annotated assignment, to closer match the order of any unions written in the type
// annotation.
if let Some(return_ty) = self.signature.return_ty
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
{
// Ignore any specialization errors here, because the type context is only used to
// optionally widen the return type.
let _ = builder.infer(return_ty, call_expression_tcx);
}
let parameters = self.signature.parameters();
for (argument_index, adjusted_argument_index, _, argument_type) in for (argument_index, adjusted_argument_index, _, argument_type) in
self.enumerate_argument_types() self.enumerate_argument_types()
{ {
@ -2316,6 +2338,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
let Some(expected_type) = parameter.annotated_type() else { let Some(expected_type) = parameter.annotated_type() else {
continue; continue;
}; };
if let Err(error) = builder.infer( if let Err(error) = builder.infer(
expected_type, expected_type,
variadic_argument_type.unwrap_or(argument_type), variadic_argument_type.unwrap_or(argument_type),
@ -2327,6 +2350,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
} }
} }
} }
self.specialization = self.signature.generic_context.map(|gc| builder.build(gc)); self.specialization = self.signature.generic_context.map(|gc| builder.build(gc));
self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| { self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| {
// The inherited generic context is used when inferring the specialization of a generic // The inherited generic context is used when inferring the specialization of a generic
@ -2688,13 +2712,19 @@ impl<'db> Binding<'db> {
self.argument_matches = matcher.finish(); self.argument_matches = matcher.finish();
} }
fn check_types(&mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>) { fn check_types(
&mut self,
db: &'db dyn Db,
arguments: &CallArguments<'_, 'db>,
call_expression_tcx: &TypeContext<'db>,
) {
let mut checker = ArgumentTypeChecker::new( let mut checker = ArgumentTypeChecker::new(
db, db,
&self.signature, &self.signature,
arguments, arguments,
&self.argument_matches, &self.argument_matches,
&mut self.parameter_tys, &mut self.parameter_tys,
call_expression_tcx,
&mut self.errors, &mut self.errors,
); );

View file

@ -352,7 +352,7 @@ struct ExpressionWithContext<'db> {
/// more precise inference results, aka "bidirectional type inference". /// more precise inference results, aka "bidirectional type inference".
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)] #[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
pub(crate) struct TypeContext<'db> { pub(crate) struct TypeContext<'db> {
annotation: Option<Type<'db>>, pub(crate) annotation: Option<Type<'db>>,
} }
impl<'db> TypeContext<'db> { impl<'db> TypeContext<'db> {

View file

@ -5775,7 +5775,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_call_expression( fn infer_call_expression(
&mut self, &mut self,
call_expression: &ast::ExprCall, call_expression: &ast::ExprCall,
_tcx: TypeContext<'db>, tcx: TypeContext<'db>,
) -> Type<'db> { ) -> Type<'db> {
let ast::ExprCall { let ast::ExprCall {
range: _, range: _,
@ -5955,7 +5955,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
} }
let mut bindings = match bindings.check_types(self.db(), &call_arguments) { let mut bindings = match bindings.check_types(self.db(), &call_arguments, &tcx) {
Ok(bindings) => bindings, Ok(bindings) => bindings,
Err(CallError(_, bindings)) => { Err(CallError(_, bindings)) => {
bindings.report_diagnostics(&self.context, call_expression.into()); bindings.report_diagnostics(&self.context, call_expression.into());
@ -8521,7 +8521,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let binding = Binding::single(value_ty, generic_context.signature(self.db())); let binding = Binding::single(value_ty, generic_context.signature(self.db()));
let bindings = match Bindings::from(binding) let bindings = match Bindings::from(binding)
.match_parameters(self.db(), &call_argument_types) .match_parameters(self.db(), &call_argument_types)
.check_types(self.db(), &call_argument_types) .check_types(self.db(), &call_argument_types, &TypeContext::default())
{ {
Ok(bindings) => bindings, Ok(bindings) => bindings,
Err(CallError(_, bindings)) => { Err(CallError(_, bindings)) => {