use type context for inference of generic function calls

This commit is contained in:
Ibraheem Ahmed 2025-09-18 16:50:41 -04:00
parent 44fc87f491
commit 5f294f9f2e
5 changed files with 90 additions and 15 deletions

View file

@ -234,3 +234,48 @@ reveal_type(x) # revealed: Foo
x: int = 1
reveal_type(x) # revealed: Literal[1]
```
## Annotations influence generic call inference
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Literal
def f[T](x: T) -> list[T]:
return [x]
a = f("a")
reveal_type(a) # revealed: list[Literal["a"]]
b: list[int | Literal["a"]] = f("a")
reveal_type(b) # revealed: list[int | Literal["a"]]
c: list[int | str] = f("a")
reveal_type(c) # revealed: list[int | str]
d: list[int | tuple[int, int]] = f((1, 2))
reveal_type(d) # revealed: list[int | tuple[int, int]]
e: list[int] = f(True)
reveal_type(e) # revealed: list[int]
# TODO: the RHS should be inferred as `list[Literal["a"]]` here
# error: [invalid-assignment] "Object of type `list[int | Literal["a"]]` is not assignable to `list[int]`"
g: list[int] = f("a")
# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `tuple[int]`"
h: tuple[int] = f("a")
def f2[T: int](x: T) -> T:
return x
i: int = f2(True)
reveal_type(i) # revealed: int
j: int | str = f2(True)
reveal_type(j) # revealed: Literal[True]
```

View file

@ -4805,7 +4805,7 @@ impl<'db> Type<'db> {
) -> Result<Bindings<'db>, CallError<'db>> {
self.bindings(db)
.match_parameters(db, argument_types)
.check_types(db, argument_types)
.check_types(db, argument_types, &TypeContext::default())
}
/// Look up a dunder method on the meta-type of `self` and call it.
@ -4854,7 +4854,7 @@ impl<'db> Type<'db> {
let bindings = dunder_callable
.bindings(db)
.match_parameters(db, argument_types)
.check_types(db, argument_types)?;
.check_types(db, argument_types, &TypeContext::default())?;
if boundness == Boundness::PossiblyUnbound {
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
}

View file

@ -32,8 +32,8 @@ use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
TrackedConstraintSet, TypeAliasType, TypeMapping, UnionType, WrapperDescriptorKind, enums,
ide_support, todo_type,
TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType,
WrapperDescriptorKind, enums, ide_support, todo_type,
};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, PythonVersion};
@ -122,6 +122,9 @@ impl<'db> Bindings<'db> {
/// You must provide an `argument_types` that was created from the same `arguments` that you
/// provided to [`match_parameters`][Self::match_parameters].
///
/// The type context of the call expression is also used to infer the specialization of generic
/// calls.
///
/// We update the bindings to include the return type of the call, the bound types for all
/// parameters, and any errors resulting from binding the call, all for each union element and
/// overload (if any).
@ -129,9 +132,12 @@ impl<'db> Bindings<'db> {
mut self,
db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>,
call_expression_tcx: &TypeContext<'db>,
) -> Result<Self, CallError<'db>> {
for element in &mut self.elements {
if let Some(mut updated_argument_forms) = element.check_types(db, argument_types) {
if let Some(mut updated_argument_forms) =
element.check_types(db, argument_types, call_expression_tcx)
{
// If this element returned a new set of argument forms (indicating successful
// argument type expansion), update the `Bindings` with these forms.
updated_argument_forms.shrink_to_fit();
@ -1281,6 +1287,7 @@ impl<'db> CallableBinding<'db> {
&mut self,
db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>,
call_expression_tcx: &TypeContext<'db>,
) -> Option<ArgumentForms> {
// If this callable is a bound method, prepend the self instance onto the arguments list
// before checking.
@ -1293,7 +1300,7 @@ impl<'db> CallableBinding<'db> {
// still perform type checking for non-overloaded function to provide better user
// experience.
if let [overload] = self.overloads.as_mut_slice() {
overload.check_types(db, argument_types.as_ref());
overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
}
return None;
}
@ -1301,7 +1308,7 @@ impl<'db> CallableBinding<'db> {
// If only one candidate overload remains, it is the winning match. Evaluate it as
// a regular (non-overloaded) call.
self.matching_overload_index = Some(index);
self.overloads[index].check_types(db, argument_types.as_ref());
self.overloads[index].check_types(db, argument_types.as_ref(), call_expression_tcx);
return None;
}
MatchingOverloadIndex::Multiple(indexes) => {
@ -1313,7 +1320,7 @@ impl<'db> CallableBinding<'db> {
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
// whether it is compatible with the supplied argument list.
for (_, overload) in self.matching_overloads_mut() {
overload.check_types(db, argument_types.as_ref());
overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
}
match self.matching_overload_index() {
@ -1430,7 +1437,7 @@ impl<'db> CallableBinding<'db> {
merged_argument_forms.merge(&argument_forms);
for (_, overload) in self.matching_overloads_mut() {
overload.check_types(db, expanded_arguments);
overload.check_types(db, expanded_arguments, call_expression_tcx);
}
let return_type = match self.matching_overload_index() {
@ -2243,6 +2250,7 @@ struct ArgumentTypeChecker<'a, 'db> {
arguments: &'a CallArguments<'a, 'db>,
argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>],
call_expression_tcx: &'a TypeContext<'db>,
errors: &'a mut Vec<BindingError<'db>>,
specialization: Option<Specialization<'db>>,
@ -2256,6 +2264,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
arguments: &'a CallArguments<'a, 'db>,
argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>],
call_expression_tcx: &'a TypeContext<'db>,
errors: &'a mut Vec<BindingError<'db>>,
) -> Self {
Self {
@ -2264,6 +2273,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
arguments,
argument_matches,
parameter_tys,
call_expression_tcx,
errors,
specialization: None,
inherited_specialization: None,
@ -2304,8 +2314,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
return;
}
let parameters = self.signature.parameters();
let mut builder = SpecializationBuilder::new(self.db);
// Note that we infer the annotated type _before_ the arguments if this call is part of
// an annotated assignment, to closer match the order of any unions written in the type
// annotation.
if let Some(return_ty) = self.signature.return_ty
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
{
// Ignore any specialization errors here, because the type context is only used to
// optionally widen the return type.
let _ = builder.infer(return_ty, call_expression_tcx);
}
let parameters = self.signature.parameters();
for (argument_index, adjusted_argument_index, _, argument_type) in
self.enumerate_argument_types()
{
@ -2316,6 +2338,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
let Some(expected_type) = parameter.annotated_type() else {
continue;
};
if let Err(error) = builder.infer(
expected_type,
variadic_argument_type.unwrap_or(argument_type),
@ -2327,6 +2350,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}
}
}
self.specialization = self.signature.generic_context.map(|gc| builder.build(gc));
self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| {
// The inherited generic context is used when inferring the specialization of a generic
@ -2688,13 +2712,19 @@ impl<'db> Binding<'db> {
self.argument_matches = matcher.finish();
}
fn check_types(&mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>) {
fn check_types(
&mut self,
db: &'db dyn Db,
arguments: &CallArguments<'_, 'db>,
call_expression_tcx: &TypeContext<'db>,
) {
let mut checker = ArgumentTypeChecker::new(
db,
&self.signature,
arguments,
&self.argument_matches,
&mut self.parameter_tys,
call_expression_tcx,
&mut self.errors,
);

View file

@ -352,7 +352,7 @@ struct ExpressionWithContext<'db> {
/// more precise inference results, aka "bidirectional type inference".
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
pub(crate) struct TypeContext<'db> {
annotation: Option<Type<'db>>,
pub(crate) annotation: Option<Type<'db>>,
}
impl<'db> TypeContext<'db> {

View file

@ -5775,7 +5775,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_call_expression(
&mut self,
call_expression: &ast::ExprCall,
_tcx: TypeContext<'db>,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprCall {
range: _,
@ -5955,7 +5955,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
let mut bindings = match bindings.check_types(self.db(), &call_arguments) {
let mut bindings = match bindings.check_types(self.db(), &call_arguments, &tcx) {
Ok(bindings) => bindings,
Err(CallError(_, bindings)) => {
bindings.report_diagnostics(&self.context, call_expression.into());
@ -8521,7 +8521,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let binding = Binding::single(value_ty, generic_context.signature(self.db()));
let bindings = match Bindings::from(binding)
.match_parameters(self.db(), &call_argument_types)
.check_types(self.db(), &call_argument_types)
.check_types(self.db(), &call_argument_types, &TypeContext::default())
{
Ok(bindings) => bindings,
Err(CallError(_, bindings)) => {