WIP: use annotated parameters as type context

This commit is contained in:
Ibraheem Ahmed 2025-09-26 01:56:59 -04:00
parent e66a872c14
commit e33bf726d7
6 changed files with 271 additions and 40 deletions

View file

@ -1682,3 +1682,57 @@ def _(arg: tuple[A | B, Any]):
reveal_type(f(arg)) # revealed: Unknown
reveal_type(f(*(arg,))) # revealed: Unknown
```
## Bi-directional Type Inference
Argument types are inferred using the parameter annotation of each matching overload as type context.
```py
from typing import TypedDict, overload
class T(TypedDict):
x: int
@overload
def f(a: list[T], b: int):
...
@overload
def f(a: list[dict[str, int]], b: str):
...
def f(a: list[dict[str, int]] | list[T], b: int | str):
...
def int_or_str() -> int | str:
return 1
f([{ "x": 1 }], int_or_str())
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `T` constructor"
# error: [invalid-key] "Invalid key access on TypedDict `T`: Unknown key "y""
f([{ "y": 1 }], int_or_str())
```
Non-matching overloads do not produce errors:
```py
from typing import TypedDict, overload
class T(TypedDict):
x: int
@overload
def f(a: T, b: int):
...
@overload
def f(a: dict[str, int], b: str):
...
def f(a: T | dict[str, int], b: int | str):
...
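# TODO: this call should not produce any diagnostics, because the `T` overload does not match (`b` is a `str`).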
f({ "y": 1 }, "a")
```

View file

@ -152,7 +152,7 @@ Person(name="Alice")
# error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor"
Person({"name": "Alice"})
# TODO: this should be an error, similar to the above
# error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor"
accepts_person({"name": "Alice"})
# TODO: this should be an error, similar to the above
house.owner = {"name": "Alice"}
@ -171,7 +171,7 @@ Person(name=None, age=30)
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`"
Person({"name": None, "age": 30})
# TODO: this should be an error, similar to the above
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`"
accepts_person({"name": None, "age": 30})
# TODO: this should be an error, similar to the above
house.owner = {"name": None, "age": 30}
@ -190,7 +190,7 @@ Person(name="Alice", age=30, extra=True)
# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra""
Person({"name": "Alice", "age": 30, "extra": True})
# TODO: this should be an error
# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra""
accepts_person({"name": "Alice", "age": 30, "extra": True})
# TODO: this should be an error
house.owner = {"name": "Alice", "age": 30, "extra": True}

View file

@ -1083,7 +1083,11 @@ impl<'db> InnerIntersectionBuilder<'db> {
// don't need to worry about finding any particular constraint more than once.
let constraints = constraints.elements(db);
let mut positive_constraint_count = 0;
for positive in &self.positive {
for (i, positive) in self.positive.iter().enumerate() {
if i == typevar_index {
continue;
}
// This linear search should be fine as long as we don't encounter typevars with
// thousands of constraints.
positive_constraint_count += constraints

View file

@ -33,10 +33,10 @@ use crate::types::{
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType,
WrapperDescriptorKind, enums, ide_support, todo_type,
WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type,
};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, PythonVersion};
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
/// Binding information for a possible union of callables. At a call site, the arguments must be
/// compatible with _all_ of the types in the union for the call to be valid.
@ -1732,7 +1732,7 @@ impl<'db> CallableBinding<'db> {
}
/// Returns the index of the matching overload in the form of [`MatchingOverloadIndex`].
fn matching_overload_index(&self) -> MatchingOverloadIndex {
pub(crate) fn matching_overload_index(&self) -> MatchingOverloadIndex {
let mut matching_overloads = self.matching_overloads();
match matching_overloads.next() {
None => MatchingOverloadIndex::None,
@ -1750,6 +1750,11 @@ impl<'db> CallableBinding<'db> {
}
}
/// Returns all overloads for this call binding, including overloads that did not match.
///
/// Unlike `matching_overloads`, no filtering is applied to the returned slice.
pub(crate) fn overloads(&self) -> &[Binding<'db>] {
self.overloads.as_slice()
}
/// Returns an iterator over all the overloads that matched for this call binding.
pub(crate) fn matching_overloads(&self) -> impl Iterator<Item = (usize, &Binding<'db>)> {
self.overloads
@ -1982,7 +1987,7 @@ enum OverloadCallReturnType<'db> {
}
#[derive(Debug)]
enum MatchingOverloadIndex {
pub(crate) enum MatchingOverloadIndex {
/// No matching overloads found.
None,
@ -2464,10 +2469,18 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
if let Some(return_ty) = self.signature.return_ty
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
{
match call_expression_tcx {
// A type variable is not a useful type-context for expression inference, and applying it
// to the return type can lead to confusing unions in nested generic calls.
Type::TypeVar(_) => {}
_ => {
// Ignore any specialization errors here, because the type context is only used to
// optionally widen the return type.
let _ = builder.infer(return_ty, call_expression_tcx);
}
}
}
let parameters = self.signature.parameters();
for (argument_index, adjusted_argument_index, _, argument_type) in
@ -3278,6 +3291,23 @@ impl<'db> BindingError<'db> {
return;
};
// Re-infer the argument type of call expressions, ignoring the type context for more
// precise error messages.
let provided_ty = match Self::get_argument_node(node, *argument_index) {
None => *provided_ty,
// Ignore starred arguments, as those are difficult to re-infer.
Some(
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }),
) => *provided_ty,
Some(
ast::ArgOrKeyword::Arg(value)
| ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }),
) => infer_isolated_expression(context.db(), context.scope(), value),
};
let provided_ty_display = provided_ty.display(context.db());
let expected_ty_display = expected_ty.display(context.db());
@ -3613,22 +3643,29 @@ impl<'db> BindingError<'db> {
}
}
fn get_node(node: ast::AnyNodeRef, argument_index: Option<usize>) -> ast::AnyNodeRef {
fn get_node(node: ast::AnyNodeRef<'_>, argument_index: Option<usize>) -> ast::AnyNodeRef<'_> {
// If we have a Call node and an argument index, report the diagnostic on the correct
// argument node; otherwise, report it on the entire provided node.
match Self::get_argument_node(node, argument_index) {
Some(ast::ArgOrKeyword::Arg(expr)) => expr.into(),
Some(ast::ArgOrKeyword::Keyword(expr)) => expr.into(),
None => node,
}
}
fn get_argument_node(
node: ast::AnyNodeRef<'_>,
argument_index: Option<usize>,
) -> Option<ArgOrKeyword<'_>> {
match (node, argument_index) {
(ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => {
match call_node
(ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => Some(
call_node
.arguments
.arguments_source_order()
.nth(argument_index)
.expect("argument index should not be out of range")
{
ast::ArgOrKeyword::Arg(expr) => expr.into(),
ast::ArgOrKeyword::Keyword(keyword) => keyword.into(),
}
}
_ => node,
.expect("argument index should not be out of range"),
),
_ => None,
}
}
}

View file

@ -1975,7 +1975,7 @@ pub(super) fn report_invalid_assignment<'db>(
if let DefinitionKind::AnnotatedAssignment(annotated_assignment) = definition.kind(context.db())
&& let Some(value) = annotated_assignment.value(context.module())
{
// Re-infer the RHS of the annotated assignment, ignoring the type context, for more precise
// Re-infer the RHS of the annotated assignment, ignoring the type context for more precise
// error messages.
source_ty = infer_isolated_expression(context.db(), definition.scope(context.db()), value);
}

View file

@ -1,4 +1,4 @@
use std::iter;
use std::{iter, mem};
use itertools::{Either, Itertools};
use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity};
@ -44,6 +44,7 @@ use crate::semantic_index::symbol::{ScopedSymbolId, Symbol};
use crate::semantic_index::{
ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table,
};
use crate::types::call::bind::MatchingOverloadIndex;
use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind};
use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator};
use crate::types::context::{InNoTypeCheck, InferContext};
@ -258,6 +259,8 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
/// is a stub file but we're still in a non-deferred region.
deferred_state: DeferredExpressionState,
multi_inference_state: MultiInferenceState,
/// For function definitions, the undecorated type of the function.
undecorated_type: Option<Type<'db>>,
@ -288,10 +291,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
context: InferContext::new(db, scope, module),
index,
region,
scope,
return_types_and_ranges: vec![],
called_functions: FxHashSet::default(),
deferred_state: DeferredExpressionState::None,
scope,
multi_inference_state: MultiInferenceState::Panic,
expressions: FxHashMap::default(),
bindings: VecMap::default(),
declarations: VecMap::default(),
@ -4917,6 +4921,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.infer_expression(expression, TypeContext::default())
}
/// Infer the argument types for a single binding.
fn infer_argument_types<'a>(
&mut self,
ast_arguments: &ast::Arguments,
@ -4926,22 +4931,113 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
debug_assert!(
ast_arguments.len() == arguments.len() && arguments.len() == argument_forms.len()
);
let iter = (arguments.iter_mut())
.zip(argument_forms.iter().copied())
.zip(ast_arguments.arguments_source_order());
for (((_, argument_type), form), arg_or_keyword) in iter {
let argument = match arg_or_keyword {
// We already inferred the type of splatted arguments.
let iter = itertools::izip!(
arguments.iter_mut(),
argument_forms.iter().copied(),
ast_arguments.arguments_source_order()
);
for ((_, argument_type), argument_form, ast_argument) in iter {
let argument = match ast_argument {
// Splatted arguments are inferred before parameter matching to
// determine their length.
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue,
ast::ArgOrKeyword::Arg(arg) => arg,
ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value,
};
let ty = self.infer_argument_type(argument, form, TypeContext::default());
let ty = self.infer_argument_type(argument, argument_form, TypeContext::default());
*argument_type = Some(ty);
}
}
/// Infer the argument types for multiple potential bindings and overloads.
///
/// Splatted arguments are skipped, and type-form arguments are inferred directly as type
/// expressions. Every other argument is inferred once per candidate overload of each binding,
/// using the annotated type of the single parameter it was matched to as type context.
///
/// The candidate overloads of a binding are its matching overloads, or its only overload if
/// the binding has exactly one overload and it did not match. Bindings with several overloads
/// and no match are skipped.
///
/// While an argument is being inferred, `multi_inference_state` is set to
/// `MultiInferenceState::Intersect`, and the previous state is restored afterwards.
fn infer_all_argument_types<'a>(
&mut self,
ast_arguments: &ast::Arguments,
arguments: &mut CallArguments<'a, 'db>,
bindings: &Bindings<'db>,
) {
debug_assert!(
ast_arguments.len() == arguments.len()
&& arguments.len() == bindings.argument_forms().len()
);
// `0..` supplies the source-order argument index used to look up the argument matches of
// each overload below.
let iter = itertools::izip!(
0..,
arguments.iter_mut(),
bindings.argument_forms().iter().copied(),
ast_arguments.arguments_source_order()
);
for (argument_index, (_, argument_type), argument_form, ast_argument) in iter {
let ast_argument = match ast_argument {
// Splatted arguments are inferred before parameter matching to
// determine their length.
//
// TODO: Re-infer splatted arguments with their type context.
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue,
ast::ArgOrKeyword::Arg(arg) => arg,
ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value,
};
// Type-form arguments are inferred without type context, so we can infer the argument type directly.
if let Some(ParameterForm::Type) = argument_form {
*argument_type = Some(self.infer_type_expression(ast_argument));
continue;
}
// Otherwise, we infer the type of each argument once for each matching overload signature,
// with the given annotated type as type context.
//
// NOTE(review): if every binding takes the `continue` branch below, this argument
// expression is never inferred on this path and `argument_type` is left untouched —
// confirm that this is intended.
for binding in bindings {
let overloads = match binding.matching_overload_index() {
MatchingOverloadIndex::Single(_) | MatchingOverloadIndex::Multiple(_) => {
Either::Right(binding.matching_overloads())
}
// If there is a single overload that does not match, we still infer the
// argument types for better diagnostics.
MatchingOverloadIndex::None => match binding.overloads() {
[overload] => Either::Left([(0, overload)].into_iter()),
_ => continue,
},
};
// Each type is a valid inference for the given argument, and we may require different
// permutations of argument types to correctly perform argument expansion during overload
// evaluation, so we take the intersection of all the types we inferred for each argument.
//
// We update the state of the inference builder to apply intersections to all nested expressions.
let old_multi_inference_state = mem::replace(
&mut self.multi_inference_state,
MultiInferenceState::Intersect,
);
for (_, overload) in overloads.into_iter() {
let argument_matches = &overload.argument_matches()[argument_index];
// Only arguments matched to exactly one parameter get a type context; arguments
// matched to zero or several parameters are skipped for this overload.
let [parameter_index] = argument_matches.parameters.as_slice() else {
continue;
};
let parameter_type =
overload.signature.parameters()[*parameter_index].annotated_type();
// The returned type is discarded here; the stored expression type is read back
// after the loop.
self.infer_expression_impl(ast_argument, TypeContext::new(parameter_type));
}
// This assignment runs once per binding, so the value left in `argument_type` is the
// one read after the last binding that was not skipped.
// NOTE(review): presumably `try_expression_type` yields `None` when no overload
// produced an inference for this argument — confirm.
*argument_type = self.try_expression_type(ast_argument);
// Restore the multi-inference state.
self.multi_inference_state = old_multi_inference_state;
}
}
}
fn infer_argument_type(
&mut self,
ast_argument: &ast::Expr,
@ -5022,12 +5118,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
types.expression_type(expression)
}
/// Infer the type of an expression.
fn infer_expression_impl(
&mut self,
expression: &ast::Expr,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ty = match expression {
let ty = self.infer_expression_impl_no_store(expression, tcx);
self.store_expression_type(expression, ty);
ty
}
/// Infer the type of an expression without storing the result.
fn infer_expression_impl_no_store(
&mut self,
expression: &ast::Expr,
tcx: TypeContext<'db>,
) -> Type<'db> {
match expression {
ast::Expr::NoneLiteral(ast::ExprNoneLiteral {
range: _,
node_index: _,
@ -5068,12 +5176,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::Expr::IpyEscapeCommand(_) => {
todo_type!("Ipy escape command support")
}
};
self.store_expression_type(expression, ty);
ty
}
}
fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) {
if self.deferred_state.in_string_annotation() {
// Avoid storing the type of expressions that are part of a string annotation because
@ -5081,9 +5186,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// on the string expression itself that represents the annotation.
return;
}
let db = self.db();
match self.multi_inference_state {
MultiInferenceState::Panic => {
let previous = self.expressions.insert(expression.into(), ty);
assert_eq!(previous, None);
}
MultiInferenceState::Intersect => {
self.expressions
.entry(expression.into())
.and_modify(|current| {
*current = IntersectionType::from_elements(db, [*current, ty])
})
.or_insert(ty);
}
}
}
fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type<'db> {
let ast::ExprNumberLiteral {
@ -5319,6 +5439,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
// TODO: Fall back to `Unknown` so that we don't eagerly error when matching against a
// potential overload.
validate_typed_dict_dict_literal(
&self.context,
typed_dict,
@ -5980,7 +6102,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let bindings = callable_type
.bindings(self.db())
.match_parameters(self.db(), &call_arguments);
self.infer_argument_types(arguments, &mut call_arguments, bindings.argument_forms());
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
// Validate `TypedDict` constructor calls after argument type inference
if let Some(class_literal) = callable_type.into_class_literal() {
@ -9101,6 +9223,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// builder only state
typevar_binding_context: _,
deferred_state: _,
multi_inference_state: _,
called_functions: _,
index: _,
region: _,
@ -9163,6 +9286,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// builder only state
typevar_binding_context: _,
deferred_state: _,
multi_inference_state: _,
called_functions: _,
index: _,
region: _,
@ -9234,6 +9358,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Builder only state
typevar_binding_context: _,
deferred_state: _,
multi_inference_state: _,
called_functions: _,
index: _,
region: _,
@ -9279,6 +9404,17 @@ impl GenericContextError {
}
}
/// Dictates the behavior when an expression is inferred multiple times.
///
/// The state is consulted whenever the type of an expression is stored on the inference
/// builder.
#[derive(Default, Debug, Clone, Copy)]
enum MultiInferenceState {
/// Panic if the expression has already been inferred.
///
/// This is the default state: storing a type asserts that no type was previously stored
/// for the same expression.
#[default]
Panic,
/// Store the intersection of all types inferred for the expression.
///
/// The first inference is stored as-is; each later inference replaces the stored type with
/// the intersection of the stored type and the newly inferred type.
Intersect,
}
/// The deferred state of a specific expression in an inference region.
#[derive(Default, Debug, Clone, Copy)]
enum DeferredExpressionState {