mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 13:24:57 +00:00
use type context for inference of generic function calls
This commit is contained in:
parent
44fc87f491
commit
5f294f9f2e
5 changed files with 90 additions and 15 deletions
|
@ -234,3 +234,48 @@ reveal_type(x) # revealed: Foo
|
||||||
x: int = 1
|
x: int = 1
|
||||||
reveal_type(x) # revealed: Literal[1]
|
reveal_type(x) # revealed: Literal[1]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Annotations influence generic call inference
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[environment]
|
||||||
|
python-version = "3.12"
|
||||||
|
```
|
||||||
|
|
||||||
|
```py
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
def f[T](x: T) -> list[T]:
|
||||||
|
return [x]
|
||||||
|
|
||||||
|
a = f("a")
|
||||||
|
reveal_type(a) # revealed: list[Literal["a"]]
|
||||||
|
|
||||||
|
b: list[int | Literal["a"]] = f("a")
|
||||||
|
reveal_type(b) # revealed: list[int | Literal["a"]]
|
||||||
|
|
||||||
|
c: list[int | str] = f("a")
|
||||||
|
reveal_type(c) # revealed: list[int | str]
|
||||||
|
|
||||||
|
d: list[int | tuple[int, int]] = f((1, 2))
|
||||||
|
reveal_type(d) # revealed: list[int | tuple[int, int]]
|
||||||
|
|
||||||
|
e: list[int] = f(True)
|
||||||
|
reveal_type(e) # revealed: list[int]
|
||||||
|
|
||||||
|
# TODO: the RHS should be inferred as `list[Literal["a"]]` here
|
||||||
|
# error: [invalid-assignment] "Object of type `list[int | Literal["a"]]` is not assignable to `list[int]`"
|
||||||
|
g: list[int] = f("a")
|
||||||
|
|
||||||
|
# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `tuple[int]`"
|
||||||
|
h: tuple[int] = f("a")
|
||||||
|
|
||||||
|
def f2[T: int](x: T) -> T:
|
||||||
|
return x
|
||||||
|
|
||||||
|
i: int = f2(True)
|
||||||
|
reveal_type(i) # revealed: int
|
||||||
|
|
||||||
|
j: int | str = f2(True)
|
||||||
|
reveal_type(j) # revealed: Literal[True]
|
||||||
|
```
|
||||||
|
|
|
@ -4805,7 +4805,7 @@ impl<'db> Type<'db> {
|
||||||
) -> Result<Bindings<'db>, CallError<'db>> {
|
) -> Result<Bindings<'db>, CallError<'db>> {
|
||||||
self.bindings(db)
|
self.bindings(db)
|
||||||
.match_parameters(db, argument_types)
|
.match_parameters(db, argument_types)
|
||||||
.check_types(db, argument_types)
|
.check_types(db, argument_types, &TypeContext::default())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Look up a dunder method on the meta-type of `self` and call it.
|
/// Look up a dunder method on the meta-type of `self` and call it.
|
||||||
|
@ -4854,7 +4854,7 @@ impl<'db> Type<'db> {
|
||||||
let bindings = dunder_callable
|
let bindings = dunder_callable
|
||||||
.bindings(db)
|
.bindings(db)
|
||||||
.match_parameters(db, argument_types)
|
.match_parameters(db, argument_types)
|
||||||
.check_types(db, argument_types)?;
|
.check_types(db, argument_types, &TypeContext::default())?;
|
||||||
if boundness == Boundness::PossiblyUnbound {
|
if boundness == Boundness::PossiblyUnbound {
|
||||||
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
|
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,8 +32,8 @@ use crate::types::tuple::{TupleLength, TupleType};
|
||||||
use crate::types::{
|
use crate::types::{
|
||||||
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
|
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
|
||||||
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
|
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
|
||||||
TrackedConstraintSet, TypeAliasType, TypeMapping, UnionType, WrapperDescriptorKind, enums,
|
TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType,
|
||||||
ide_support, todo_type,
|
WrapperDescriptorKind, enums, ide_support, todo_type,
|
||||||
};
|
};
|
||||||
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
|
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
|
||||||
use ruff_python_ast::{self as ast, PythonVersion};
|
use ruff_python_ast::{self as ast, PythonVersion};
|
||||||
|
@ -122,6 +122,9 @@ impl<'db> Bindings<'db> {
|
||||||
/// You must provide an `argument_types` that was created from the same `arguments` that you
|
/// You must provide an `argument_types` that was created from the same `arguments` that you
|
||||||
/// provided to [`match_parameters`][Self::match_parameters].
|
/// provided to [`match_parameters`][Self::match_parameters].
|
||||||
///
|
///
|
||||||
|
/// The type context of the call expression is also used to infer the specialization of generic
|
||||||
|
/// calls.
|
||||||
|
///
|
||||||
/// We update the bindings to include the return type of the call, the bound types for all
|
/// We update the bindings to include the return type of the call, the bound types for all
|
||||||
/// parameters, and any errors resulting from binding the call, all for each union element and
|
/// parameters, and any errors resulting from binding the call, all for each union element and
|
||||||
/// overload (if any).
|
/// overload (if any).
|
||||||
|
@ -129,9 +132,12 @@ impl<'db> Bindings<'db> {
|
||||||
mut self,
|
mut self,
|
||||||
db: &'db dyn Db,
|
db: &'db dyn Db,
|
||||||
argument_types: &CallArguments<'_, 'db>,
|
argument_types: &CallArguments<'_, 'db>,
|
||||||
|
call_expression_tcx: &TypeContext<'db>,
|
||||||
) -> Result<Self, CallError<'db>> {
|
) -> Result<Self, CallError<'db>> {
|
||||||
for element in &mut self.elements {
|
for element in &mut self.elements {
|
||||||
if let Some(mut updated_argument_forms) = element.check_types(db, argument_types) {
|
if let Some(mut updated_argument_forms) =
|
||||||
|
element.check_types(db, argument_types, call_expression_tcx)
|
||||||
|
{
|
||||||
// If this element returned a new set of argument forms (indicating successful
|
// If this element returned a new set of argument forms (indicating successful
|
||||||
// argument type expansion), update the `Bindings` with these forms.
|
// argument type expansion), update the `Bindings` with these forms.
|
||||||
updated_argument_forms.shrink_to_fit();
|
updated_argument_forms.shrink_to_fit();
|
||||||
|
@ -1281,6 +1287,7 @@ impl<'db> CallableBinding<'db> {
|
||||||
&mut self,
|
&mut self,
|
||||||
db: &'db dyn Db,
|
db: &'db dyn Db,
|
||||||
argument_types: &CallArguments<'_, 'db>,
|
argument_types: &CallArguments<'_, 'db>,
|
||||||
|
call_expression_tcx: &TypeContext<'db>,
|
||||||
) -> Option<ArgumentForms> {
|
) -> Option<ArgumentForms> {
|
||||||
// If this callable is a bound method, prepend the self instance onto the arguments list
|
// If this callable is a bound method, prepend the self instance onto the arguments list
|
||||||
// before checking.
|
// before checking.
|
||||||
|
@ -1293,7 +1300,7 @@ impl<'db> CallableBinding<'db> {
|
||||||
// still perform type checking for non-overloaded function to provide better user
|
// still perform type checking for non-overloaded function to provide better user
|
||||||
// experience.
|
// experience.
|
||||||
if let [overload] = self.overloads.as_mut_slice() {
|
if let [overload] = self.overloads.as_mut_slice() {
|
||||||
overload.check_types(db, argument_types.as_ref());
|
overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
|
||||||
}
|
}
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
@ -1301,7 +1308,7 @@ impl<'db> CallableBinding<'db> {
|
||||||
// If only one candidate overload remains, it is the winning match. Evaluate it as
|
// If only one candidate overload remains, it is the winning match. Evaluate it as
|
||||||
// a regular (non-overloaded) call.
|
// a regular (non-overloaded) call.
|
||||||
self.matching_overload_index = Some(index);
|
self.matching_overload_index = Some(index);
|
||||||
self.overloads[index].check_types(db, argument_types.as_ref());
|
self.overloads[index].check_types(db, argument_types.as_ref(), call_expression_tcx);
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
MatchingOverloadIndex::Multiple(indexes) => {
|
MatchingOverloadIndex::Multiple(indexes) => {
|
||||||
|
@ -1313,7 +1320,7 @@ impl<'db> CallableBinding<'db> {
|
||||||
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
|
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
|
||||||
// whether it is compatible with the supplied argument list.
|
// whether it is compatible with the supplied argument list.
|
||||||
for (_, overload) in self.matching_overloads_mut() {
|
for (_, overload) in self.matching_overloads_mut() {
|
||||||
overload.check_types(db, argument_types.as_ref());
|
overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
|
||||||
}
|
}
|
||||||
|
|
||||||
match self.matching_overload_index() {
|
match self.matching_overload_index() {
|
||||||
|
@ -1430,7 +1437,7 @@ impl<'db> CallableBinding<'db> {
|
||||||
merged_argument_forms.merge(&argument_forms);
|
merged_argument_forms.merge(&argument_forms);
|
||||||
|
|
||||||
for (_, overload) in self.matching_overloads_mut() {
|
for (_, overload) in self.matching_overloads_mut() {
|
||||||
overload.check_types(db, expanded_arguments);
|
overload.check_types(db, expanded_arguments, call_expression_tcx);
|
||||||
}
|
}
|
||||||
|
|
||||||
let return_type = match self.matching_overload_index() {
|
let return_type = match self.matching_overload_index() {
|
||||||
|
@ -2243,6 +2250,7 @@ struct ArgumentTypeChecker<'a, 'db> {
|
||||||
arguments: &'a CallArguments<'a, 'db>,
|
arguments: &'a CallArguments<'a, 'db>,
|
||||||
argument_matches: &'a [MatchedArgument<'db>],
|
argument_matches: &'a [MatchedArgument<'db>],
|
||||||
parameter_tys: &'a mut [Option<Type<'db>>],
|
parameter_tys: &'a mut [Option<Type<'db>>],
|
||||||
|
call_expression_tcx: &'a TypeContext<'db>,
|
||||||
errors: &'a mut Vec<BindingError<'db>>,
|
errors: &'a mut Vec<BindingError<'db>>,
|
||||||
|
|
||||||
specialization: Option<Specialization<'db>>,
|
specialization: Option<Specialization<'db>>,
|
||||||
|
@ -2256,6 +2264,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
||||||
arguments: &'a CallArguments<'a, 'db>,
|
arguments: &'a CallArguments<'a, 'db>,
|
||||||
argument_matches: &'a [MatchedArgument<'db>],
|
argument_matches: &'a [MatchedArgument<'db>],
|
||||||
parameter_tys: &'a mut [Option<Type<'db>>],
|
parameter_tys: &'a mut [Option<Type<'db>>],
|
||||||
|
call_expression_tcx: &'a TypeContext<'db>,
|
||||||
errors: &'a mut Vec<BindingError<'db>>,
|
errors: &'a mut Vec<BindingError<'db>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
@ -2264,6 +2273,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
||||||
arguments,
|
arguments,
|
||||||
argument_matches,
|
argument_matches,
|
||||||
parameter_tys,
|
parameter_tys,
|
||||||
|
call_expression_tcx,
|
||||||
errors,
|
errors,
|
||||||
specialization: None,
|
specialization: None,
|
||||||
inherited_specialization: None,
|
inherited_specialization: None,
|
||||||
|
@ -2304,8 +2314,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let parameters = self.signature.parameters();
|
|
||||||
let mut builder = SpecializationBuilder::new(self.db);
|
let mut builder = SpecializationBuilder::new(self.db);
|
||||||
|
|
||||||
|
// Note that we infer the annotated type _before_ the arguments if this call is part of
|
||||||
|
// an annotated assignment, to closer match the order of any unions written in the type
|
||||||
|
// annotation.
|
||||||
|
if let Some(return_ty) = self.signature.return_ty
|
||||||
|
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
|
||||||
|
{
|
||||||
|
// Ignore any specialization errors here, because the type context is only used to
|
||||||
|
// optionally widen the return type.
|
||||||
|
let _ = builder.infer(return_ty, call_expression_tcx);
|
||||||
|
}
|
||||||
|
|
||||||
|
let parameters = self.signature.parameters();
|
||||||
for (argument_index, adjusted_argument_index, _, argument_type) in
|
for (argument_index, adjusted_argument_index, _, argument_type) in
|
||||||
self.enumerate_argument_types()
|
self.enumerate_argument_types()
|
||||||
{
|
{
|
||||||
|
@ -2316,6 +2338,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
||||||
let Some(expected_type) = parameter.annotated_type() else {
|
let Some(expected_type) = parameter.annotated_type() else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(error) = builder.infer(
|
if let Err(error) = builder.infer(
|
||||||
expected_type,
|
expected_type,
|
||||||
variadic_argument_type.unwrap_or(argument_type),
|
variadic_argument_type.unwrap_or(argument_type),
|
||||||
|
@ -2327,6 +2350,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.specialization = self.signature.generic_context.map(|gc| builder.build(gc));
|
self.specialization = self.signature.generic_context.map(|gc| builder.build(gc));
|
||||||
self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| {
|
self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| {
|
||||||
// The inherited generic context is used when inferring the specialization of a generic
|
// The inherited generic context is used when inferring the specialization of a generic
|
||||||
|
@ -2688,13 +2712,19 @@ impl<'db> Binding<'db> {
|
||||||
self.argument_matches = matcher.finish();
|
self.argument_matches = matcher.finish();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_types(&mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>) {
|
fn check_types(
|
||||||
|
&mut self,
|
||||||
|
db: &'db dyn Db,
|
||||||
|
arguments: &CallArguments<'_, 'db>,
|
||||||
|
call_expression_tcx: &TypeContext<'db>,
|
||||||
|
) {
|
||||||
let mut checker = ArgumentTypeChecker::new(
|
let mut checker = ArgumentTypeChecker::new(
|
||||||
db,
|
db,
|
||||||
&self.signature,
|
&self.signature,
|
||||||
arguments,
|
arguments,
|
||||||
&self.argument_matches,
|
&self.argument_matches,
|
||||||
&mut self.parameter_tys,
|
&mut self.parameter_tys,
|
||||||
|
call_expression_tcx,
|
||||||
&mut self.errors,
|
&mut self.errors,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -352,7 +352,7 @@ struct ExpressionWithContext<'db> {
|
||||||
/// more precise inference results, aka "bidirectional type inference".
|
/// more precise inference results, aka "bidirectional type inference".
|
||||||
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
|
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
|
||||||
pub(crate) struct TypeContext<'db> {
|
pub(crate) struct TypeContext<'db> {
|
||||||
annotation: Option<Type<'db>>,
|
pub(crate) annotation: Option<Type<'db>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'db> TypeContext<'db> {
|
impl<'db> TypeContext<'db> {
|
||||||
|
|
|
@ -5775,7 +5775,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
fn infer_call_expression(
|
fn infer_call_expression(
|
||||||
&mut self,
|
&mut self,
|
||||||
call_expression: &ast::ExprCall,
|
call_expression: &ast::ExprCall,
|
||||||
_tcx: TypeContext<'db>,
|
tcx: TypeContext<'db>,
|
||||||
) -> Type<'db> {
|
) -> Type<'db> {
|
||||||
let ast::ExprCall {
|
let ast::ExprCall {
|
||||||
range: _,
|
range: _,
|
||||||
|
@ -5955,7 +5955,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut bindings = match bindings.check_types(self.db(), &call_arguments) {
|
let mut bindings = match bindings.check_types(self.db(), &call_arguments, &tcx) {
|
||||||
Ok(bindings) => bindings,
|
Ok(bindings) => bindings,
|
||||||
Err(CallError(_, bindings)) => {
|
Err(CallError(_, bindings)) => {
|
||||||
bindings.report_diagnostics(&self.context, call_expression.into());
|
bindings.report_diagnostics(&self.context, call_expression.into());
|
||||||
|
@ -8521,7 +8521,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
let binding = Binding::single(value_ty, generic_context.signature(self.db()));
|
let binding = Binding::single(value_ty, generic_context.signature(self.db()));
|
||||||
let bindings = match Bindings::from(binding)
|
let bindings = match Bindings::from(binding)
|
||||||
.match_parameters(self.db(), &call_argument_types)
|
.match_parameters(self.db(), &call_argument_types)
|
||||||
.check_types(self.db(), &call_argument_types)
|
.check_types(self.db(), &call_argument_types, &TypeContext::default())
|
||||||
{
|
{
|
||||||
Ok(bindings) => bindings,
|
Ok(bindings) => bindings,
|
||||||
Err(CallError(_, bindings)) => {
|
Err(CallError(_, bindings)) => {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue