From e33bf726d7d26d351f205ac9b2c6bdbece04bc48 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 26 Sep 2025 01:56:59 -0400 Subject: [PATCH] WIP: use annotated parameters as type context --- .../resources/mdtest/call/overloads.md | 54 ++++++ .../resources/mdtest/typed_dict.md | 6 +- .../ty_python_semantic/src/types/builder.rs | 6 +- .../ty_python_semantic/src/types/call/bind.rs | 71 ++++++-- .../src/types/diagnostic.rs | 2 +- .../src/types/infer/builder.rs | 172 ++++++++++++++++-- 6 files changed, 271 insertions(+), 40 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index fcb23dc077..cd1fede988 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -1682,3 +1682,57 @@ def _(arg: tuple[A | B, Any]): reveal_type(f(arg)) # revealed: Unknown reveal_type(f(*(arg,))) # revealed: Unknown ``` + +## Bi-directional Type Inference + +Type inference accounts for the type of each overload. + +```py +from typing import TypedDict, overload + +class T(TypedDict): + x: int + +@overload +def f(a: list[T], b: int): + ... + +@overload +def f(a: list[dict[str, int]], b: str): + ... + +def f(a: list[dict[str, int]] | list[T], b: int | str): + ... + +def int_or_str() -> int | str: + return 1 + +f([{ "x": 1 }], int_or_str()) + +# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `T` constructor" +# error: [invalid-key] "Invalid key access on TypedDict `T`: Unknown key "y"" +f([{ "y": 1 }], int_or_str()) +``` + +Non-matching overloads do not produce errors: + +```py +from typing import TypedDict, overload + +class T(TypedDict): + x: int + +@overload +def f(a: T, b: int): + ... + +@overload +def f(a: dict[str, int], b: str): + ... + +def f(a: T | dict[str, int], b: int | str): + ... + +# TODO +f({ "y": 1 }, "a") +``` diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index 3ad8df2eba..6e3a275d60 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -152,7 +152,7 @@ Person(name="Alice") # error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor" Person({"name": "Alice"}) -# TODO: this should be an error, similar to the above +# error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor" accepts_person({"name": "Alice"}) # TODO: this should be an error, similar to the above house.owner = {"name": "Alice"} @@ -171,7 +171,7 @@ Person(name=None, age=30) # error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`" Person({"name": None, "age": 30}) -# TODO: this should be an error, similar to the above +# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`" accepts_person({"name": None, "age": 30}) # TODO: this should be an error, similar to the above house.owner = {"name": None, "age": 30} @@ -190,7 +190,7 @@ Person(name="Alice", age=30, extra=True) # error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra"" Person({"name": "Alice", "age": 30, "extra": True}) -# TODO: this should be an error +# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra"" accepts_person({"name": "Alice", "age": 30, "extra": True}) # TODO: this should be an error house.owner = {"name": "Alice", "age": 30, "extra": True} diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index cc1c35790a..eb1394f988 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -1083,7 +1083,11 @@ impl<'db> InnerIntersectionBuilder<'db> { // don't need to worry about finding any particular constraint more than once. let constraints = constraints.elements(db); let mut positive_constraint_count = 0; - for positive in &self.positive { + for (i, positive) in self.positive.iter().enumerate() { + if i == typevar_index { + continue; + } + // This linear search should be fine as long as we don't encounter typevars with // thousands of constraints. positive_constraint_count += constraints diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 59a38e303b..ff20563749 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -33,10 +33,10 @@ use crate::types::{ BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType, - WrapperDescriptorKind, enums, ide_support, todo_type, + WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type, }; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; -use ruff_python_ast::{self as ast, PythonVersion}; +use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion}; /// Binding information for a possible union of callables. At a call site, the arguments must be /// compatible with _all_ of the types in the union for the call to be valid. @@ -1732,7 +1732,7 @@ impl<'db> CallableBinding<'db> { } /// Returns the index of the matching overload in the form of [`MatchingOverloadIndex`]. - fn matching_overload_index(&self) -> MatchingOverloadIndex { + pub(crate) fn matching_overload_index(&self) -> MatchingOverloadIndex { let mut matching_overloads = self.matching_overloads(); match matching_overloads.next() { None => MatchingOverloadIndex::None, @@ -1750,6 +1750,11 @@ impl<'db> CallableBinding<'db> { } } + /// Returns all overloads for this call binding, including overloads that did not match. + pub(crate) fn overloads(&self) -> &[Binding<'db>] { + self.overloads.as_slice() + } + /// Returns an iterator over all the overloads that matched for this call binding. pub(crate) fn matching_overloads(&self) -> impl Iterator)> { self.overloads @@ -1982,7 +1987,7 @@ enum OverloadCallReturnType<'db> { } #[derive(Debug)] -enum MatchingOverloadIndex { +pub(crate) enum MatchingOverloadIndex { /// No matching overloads found. None, @@ -2464,9 +2469,17 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { if let Some(return_ty) = self.signature.return_ty && let Some(call_expression_tcx) = self.call_expression_tcx.annotation { - // Ignore any specialization errors here, because the type context is only used to - // optionally widen the return type. - let _ = builder.infer(return_ty, call_expression_tcx); + match call_expression_tcx { + // A type variable is not a useful type-context for expression inference, and applying it + // to the return type can lead to confusing unions in nested generic calls. + Type::TypeVar(_) => {} + + _ => { + // Ignore any specialization errors here, because the type context is only used to + // optionally widen the return type. + let _ = builder.infer(return_ty, call_expression_tcx); + } + } } let parameters = self.signature.parameters(); @@ -3278,6 +3291,23 @@ impl<'db> BindingError<'db> { return; }; + // Re-infer the argument type of call expressions, ignoring the type context for more + // precise error messages. + let provided_ty = match Self::get_argument_node(node, *argument_index) { + None => *provided_ty, + + // Ignore starred arguments, as those are difficult to re-infer. + Some( + ast::ArgOrKeyword::Arg(ast::Expr::Starred(_)) + | ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }), + ) => *provided_ty, + + Some( + ast::ArgOrKeyword::Arg(value) + | ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }), + ) => infer_isolated_expression(context.db(), context.scope(), value), + }; + let provided_ty_display = provided_ty.display(context.db()); let expected_ty_display = expected_ty.display(context.db()); @@ -3613,22 +3643,29 @@ impl<'db> BindingError<'db> { } } - fn get_node(node: ast::AnyNodeRef, argument_index: Option) -> ast::AnyNodeRef { + fn get_node(node: ast::AnyNodeRef<'_>, argument_index: Option) -> ast::AnyNodeRef<'_> { // If we have a Call node and an argument index, report the diagnostic on the correct // argument node; otherwise, report it on the entire provided node. + match Self::get_argument_node(node, argument_index) { + Some(ast::ArgOrKeyword::Arg(expr)) => expr.into(), + Some(ast::ArgOrKeyword::Keyword(expr)) => expr.into(), + None => node, + } + } + + fn get_argument_node( + node: ast::AnyNodeRef<'_>, + argument_index: Option, + ) -> Option> { match (node, argument_index) { - (ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => { - match call_node + (ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => Some( + call_node .arguments .arguments_source_order() .nth(argument_index) - .expect("argument index should not be out of range") - { - ast::ArgOrKeyword::Arg(expr) => expr.into(), - ast::ArgOrKeyword::Keyword(keyword) => keyword.into(), - } - } - _ => node, + .expect("argument index should not be out of range"), + ), + _ => None, } } } diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index 214eb30749..15df9bed79 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -1975,7 +1975,7 @@ pub(super) fn report_invalid_assignment<'db>( if let DefinitionKind::AnnotatedAssignment(annotated_assignment) = definition.kind(context.db()) && let Some(value) = annotated_assignment.value(context.module()) { - // Re-infer the RHS of the annotated assignment, ignoring the type context, for more precise + // Re-infer the RHS of the annotated assignment, ignoring the type context for more precise // error messages. source_ty = infer_isolated_expression(context.db(), definition.scope(context.db()), value); } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 298c3b4644..a47078568b 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -1,4 +1,4 @@ -use std::iter; +use std::{iter, mem}; use itertools::{Either, Itertools}; use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity}; @@ -44,6 +44,7 @@ use crate::semantic_index::symbol::{ScopedSymbolId, Symbol}; use crate::semantic_index::{ ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table, }; +use crate::types::call::bind::MatchingOverloadIndex; use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind}; use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator}; use crate::types::context::{InNoTypeCheck, InferContext}; @@ -258,6 +259,8 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { /// is a stub file but we're still in a non-deferred region. deferred_state: DeferredExpressionState, + multi_inference_state: MultiInferenceState, + /// For function definitions, the undecorated type of the function. undecorated_type: Option>, @@ -288,10 +291,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { context: InferContext::new(db, scope, module), index, region, + scope, return_types_and_ranges: vec![], called_functions: FxHashSet::default(), deferred_state: DeferredExpressionState::None, - scope, + multi_inference_state: MultiInferenceState::Panic, expressions: FxHashMap::default(), bindings: VecMap::default(), declarations: VecMap::default(), @@ -4917,6 +4921,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.infer_expression(expression, TypeContext::default()) } + /// Infer the argument types for a single binding. fn infer_argument_types<'a>( &mut self, ast_arguments: &ast::Arguments, @@ -4926,22 +4931,113 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { debug_assert!( ast_arguments.len() == arguments.len() && arguments.len() == argument_forms.len() ); - let iter = (arguments.iter_mut()) - .zip(argument_forms.iter().copied()) - .zip(ast_arguments.arguments_source_order()); - for (((_, argument_type), form), arg_or_keyword) in iter { - let argument = match arg_or_keyword { - // We already inferred the type of splatted arguments. + + let iter = itertools::izip!( + arguments.iter_mut(), + argument_forms.iter().copied(), + ast_arguments.arguments_source_order() + ); + + for ((_, argument_type), argument_form, ast_argument) in iter { + let argument = match ast_argument { + // Splatted arguments are inferred before parameter matching to + // determine their length. ast::ArgOrKeyword::Arg(ast::Expr::Starred(_)) | ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue, + ast::ArgOrKeyword::Arg(arg) => arg, ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value, }; - let ty = self.infer_argument_type(argument, form, TypeContext::default()); + + let ty = self.infer_argument_type(argument, argument_form, TypeContext::default()); *argument_type = Some(ty); } } + /// Infer the argument types for multiple potential binding and overloads. + fn infer_all_argument_types<'a>( + &mut self, + ast_arguments: &ast::Arguments, + arguments: &mut CallArguments<'a, 'db>, + bindings: &Bindings<'db>, + ) { + debug_assert!( + ast_arguments.len() == arguments.len() + && arguments.len() == bindings.argument_forms().len() + ); + + let iter = itertools::izip!( + 0.., + arguments.iter_mut(), + bindings.argument_forms().iter().copied(), + ast_arguments.arguments_source_order() + ); + + for (argument_index, (_, argument_type), argument_form, ast_argument) in iter { + let ast_argument = match ast_argument { + // Splatted arguments are inferred before parameter matching to + // determine their length. + // + // TODO: Re-infer splatted arguments with their type context. + ast::ArgOrKeyword::Arg(ast::Expr::Starred(_)) + | ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue, + + ast::ArgOrKeyword::Arg(arg) => arg, + ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value, + }; + + // Type-form arguments are inferred without type context, so we can infer the argument type directly. + if let Some(ParameterForm::Type) = argument_form { + *argument_type = Some(self.infer_type_expression(ast_argument)); + continue; + } + + // Otherwise, we infer the type of each argument once for each matching overload signature, + // with the given annotated type as type context. + for binding in bindings { + let overloads = match binding.matching_overload_index() { + MatchingOverloadIndex::Single(_) | MatchingOverloadIndex::Multiple(_) => { + Either::Right(binding.matching_overloads()) + } + + // If there is a single overload that does not match, we still infer the + // argument types for better diagnostics. + MatchingOverloadIndex::None => match binding.overloads() { + [overload] => Either::Left([(0, overload)].into_iter()), + _ => continue, + }, + }; + + // Each type is a valid inference for the given argument, and we may require different + // permutations of argument types to correctly perform argument expansion during overload + // evaluation, so we take the intersection of all the types we inferred for each argument. + // + // We update the state of the inference builder to apply intersections to all nested expressions. + let old_multi_inference_state = mem::replace( + &mut self.multi_inference_state, + MultiInferenceState::Intersect, + ); + + for (_, overload) in overloads.into_iter() { + let argument_matches = &overload.argument_matches()[argument_index]; + let [parameter_index] = argument_matches.parameters.as_slice() else { + continue; + }; + + let parameter_type = + overload.signature.parameters()[*parameter_index].annotated_type(); + + self.infer_expression_impl(ast_argument, TypeContext::new(parameter_type)); + } + + *argument_type = self.try_expression_type(ast_argument); + + // Restore the multi-inference state. + self.multi_inference_state = old_multi_inference_state; + } + } + } + fn infer_argument_type( &mut self, ast_argument: &ast::Expr, @@ -5022,12 +5118,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { types.expression_type(expression) } + /// Infer the type of an expression. fn infer_expression_impl( &mut self, expression: &ast::Expr, tcx: TypeContext<'db>, ) -> Type<'db> { - let ty = match expression { + let ty = self.infer_expression_impl_no_store(expression, tcx); + self.store_expression_type(expression, ty); + ty + } + + /// Infer the type of an expression without storing the result. + fn infer_expression_impl_no_store( + &mut self, + expression: &ast::Expr, + tcx: TypeContext<'db>, + ) -> Type<'db> { + match expression { ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _, node_index: _, @@ -5068,12 +5176,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::Expr::IpyEscapeCommand(_) => { todo_type!("Ipy escape command support") } - }; - - self.store_expression_type(expression, ty); - - ty + } } + fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) { if self.deferred_state.in_string_annotation() { // Avoid storing the type of expressions that are part of a string annotation because @@ -5081,8 +5186,23 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // on the string expression itself that represents the annotation. return; } - let previous = self.expressions.insert(expression.into(), ty); - assert_eq!(previous, None); + + let db = self.db(); + + match self.multi_inference_state { + MultiInferenceState::Panic => { + let previous = self.expressions.insert(expression.into(), ty); + assert_eq!(previous, None); + } + MultiInferenceState::Intersect => { + self.expressions + .entry(expression.into()) + .and_modify(|current| { + *current = IntersectionType::from_elements(db, [*current, ty]) + }) + .or_insert(ty); + } + } } fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type<'db> { @@ -5319,6 +5439,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } + // TODO: Fall-back to `Unknown` so that we don't eagerly error when matching against a + // potential overload. validate_typed_dict_dict_literal( &self.context, typed_dict, @@ -5980,7 +6102,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let bindings = callable_type .bindings(self.db()) .match_parameters(self.db(), &call_arguments); - self.infer_argument_types(arguments, &mut call_arguments, bindings.argument_forms()); + self.infer_all_argument_types(arguments, &mut call_arguments, &bindings); // Validate `TypedDict` constructor calls after argument type inference if let Some(class_literal) = callable_type.into_class_literal() { @@ -9101,6 +9223,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // builder only state typevar_binding_context: _, deferred_state: _, + multi_inference_state: _, called_functions: _, index: _, region: _, @@ -9163,6 +9286,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // builder only state typevar_binding_context: _, deferred_state: _, + multi_inference_state: _, called_functions: _, index: _, region: _, @@ -9234,6 +9358,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Builder only state typevar_binding_context: _, deferred_state: _, + multi_inference_state: _, called_functions: _, index: _, region: _, @@ -9279,6 +9404,17 @@ impl GenericContextError { } } +/// Dictates the behavior when an expression is inferred multiple times. +#[derive(Default, Debug, Clone, Copy)] +enum MultiInferenceState { + /// Panic if the expression has already been inferred. + #[default] + Panic, + + /// Store the intersection of all types inferred for the expression. + Intersect, +} + /// The deferred state of a specific expression in an inference region. #[derive(Default, Debug, Clone, Copy)] enum DeferredExpressionState {