diff --git a/crates/ty_python_semantic/resources/mdtest/snapshots/functions.md_-_Generic_functions___L…_-_Inferring_a_bound_ty…_(d50204b9d91b7bd1).snap.new b/crates/ty_python_semantic/resources/mdtest/snapshots/functions.md_-_Generic_functions___L…_-_Inferring_a_bound_ty…_(d50204b9d91b7bd1).snap.new new file mode 100644 index 0000000000..f8111dc53c --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/snapshots/functions.md_-_Generic_functions___L…_-_Inferring_a_bound_ty…_(d50204b9d91b7bd1).snap.new @@ -0,0 +1,92 @@ +--- +source: crates/ty_test/src/lib.rs +assertion_line: 421 +expression: snapshot +--- +--- +mdtest name: functions.md - Generic functions: Legacy syntax - Inferring a bound typevar +mdtest path: crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +--- + +# Python source files + +## mdtest_snippet.py + +``` + 1 | from typing import TypeVar + 2 | from typing_extensions import reveal_type + 3 | + 4 | T = TypeVar("T", bound=int) + 5 | + 6 | def f(x: T) -> T: + 7 | return x + 8 | + 9 | reveal_type(f(1)) # revealed: Literal[1] +10 | reveal_type(f(True)) # revealed: Literal[True] +11 | # error: [invalid-argument-type] +12 | reveal_type(f("string")) # revealed: Unknown +``` + +# Diagnostics + +``` +info[revealed-type]: Revealed type + --> src/mdtest_snippet.py:9:13 + | + 7 | return x + 8 | + 9 | reveal_type(f(1)) # revealed: Literal[1] + | ^^^^ `_T@reveal_type | Literal[1]` +10 | reveal_type(f(True)) # revealed: Literal[True] +11 | # error: [invalid-argument-type] + | + +``` + +``` +info[revealed-type]: Revealed type + --> src/mdtest_snippet.py:10:13 + | + 9 | reveal_type(f(1)) # revealed: Literal[1] +10 | reveal_type(f(True)) # revealed: Literal[True] + | ^^^^^^^ `_T@reveal_type | Literal[True]` +11 | # error: [invalid-argument-type] +12 | reveal_type(f("string")) # revealed: Unknown + | + +``` + +``` +info[revealed-type]: Revealed type + --> src/mdtest_snippet.py:12:13 + | +10 | reveal_type(f(True)) # revealed: Literal[True] +11 | # error: [invalid-argument-type] +12 | reveal_type(f("string")) # revealed: Unknown + | ^^^^^^^^^^^ `_T@reveal_type` + | + +``` + +``` +error[invalid-argument-type]: Argument to function `f` is incorrect + --> src/mdtest_snippet.py:12:15 + | +10 | reveal_type(f(True)) # revealed: Literal[True] +11 | # error: [invalid-argument-type] +12 | reveal_type(f("string")) # revealed: Unknown + | ^^^^^^^^ Argument type `Literal["string"]` does not satisfy upper bound `int` of type variable `T` + | +info: Type variable defined here + --> src/mdtest_snippet.py:4:1 + | +2 | from typing_extensions import reveal_type +3 | +4 | T = TypeVar("T", bound=int) + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ +5 | +6 | def f(x: T) -> T: + | +info: rule `invalid-argument-type` is enabled by default + +``` diff --git a/crates/ty_python_semantic/resources/mdtest/snapshots/functions.md_-_Generic_functions___P…_-_Inferring_a_bound_ty…_(5935d14c26afe407).snap.new b/crates/ty_python_semantic/resources/mdtest/snapshots/functions.md_-_Generic_functions___P…_-_Inferring_a_bound_ty…_(5935d14c26afe407).snap.new new file mode 100644 index 0000000000..1d013bc63e --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/snapshots/functions.md_-_Generic_functions___P…_-_Inferring_a_bound_ty…_(5935d14c26afe407).snap.new @@ -0,0 +1,88 @@ +--- +source: crates/ty_test/src/lib.rs +assertion_line: 421 +expression: snapshot +--- +--- +mdtest name: functions.md - Generic functions: PEP 695 syntax - Inferring a bound typevar +mdtest path: crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +--- + +# Python source files + +## mdtest_snippet.py + +``` +1 | from typing_extensions import reveal_type +2 | +3 | def f[T: int](x: T) -> T: +4 | return x +5 | +6 | reveal_type(f(1)) # revealed: Literal[1] +7 | reveal_type(f(True)) # revealed: Literal[True] +8 | # error: [invalid-argument-type] +9 | reveal_type(f("string")) # revealed: Unknown +``` + +# Diagnostics + +``` +info[revealed-type]: Revealed type + --> src/mdtest_snippet.py:6:13 + | +4 | return x +5 | +6 | reveal_type(f(1)) # revealed: Literal[1] + | ^^^^ `_T@reveal_type | Literal[1]` +7 | reveal_type(f(True)) # revealed: Literal[True] +8 | # error: [invalid-argument-type] + | + +``` + +``` +info[revealed-type]: Revealed type + --> src/mdtest_snippet.py:7:13 + | +6 | reveal_type(f(1)) # revealed: Literal[1] +7 | reveal_type(f(True)) # revealed: Literal[True] + | ^^^^^^^ `_T@reveal_type | Literal[True]` +8 | # error: [invalid-argument-type] +9 | reveal_type(f("string")) # revealed: Unknown + | + +``` + +``` +info[revealed-type]: Revealed type + --> src/mdtest_snippet.py:9:13 + | +7 | reveal_type(f(True)) # revealed: Literal[True] +8 | # error: [invalid-argument-type] +9 | reveal_type(f("string")) # revealed: Unknown + | ^^^^^^^^^^^ `_T@reveal_type` + | + +``` + +``` +error[invalid-argument-type]: Argument to function `f` is incorrect + --> src/mdtest_snippet.py:9:15 + | +7 | reveal_type(f(True)) # revealed: Literal[True] +8 | # error: [invalid-argument-type] +9 | reveal_type(f("string")) # revealed: Unknown + | ^^^^^^^^ Argument type `Literal["string"]` does not satisfy upper bound `int` of type variable `T` + | +info: Type variable defined here + --> src/mdtest_snippet.py:3:7 + | +1 | from typing_extensions import reveal_type +2 | +3 | def f[T: int](x: T) -> T: + | ^^^^^^ +4 | return x + | +info: rule `invalid-argument-type` is enabled by default + +``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index b2cb2d61f1..75b09997a2 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4853,11 +4853,10 @@ impl<'db> Type<'db> { fn try_call( self, db: &'db dyn Db, - argument_types: &CallArguments<'_, 'db>, + arguments: &CallArguments<'_, 'db>, ) -> Result, CallError<'db>> { - self.bindings(db) - .match_parameters(db, argument_types) - .check_types(db, argument_types, &TypeContext::default()) + let (bindings, argument_types) = self.bindings(db).match_parameters(db, arguments); + bindings.check_types(db, arguments, &argument_types, &TypeContext::default()) } /// Look up a dunder method on the meta-type of `self` and call it. @@ -4889,7 +4888,7 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, name: &str, - argument_types: &mut CallArguments<'_, 'db>, + arguments: &mut CallArguments<'_, 'db>, policy: MemberLookupPolicy, ) -> Result, CallDunderError<'db>> { // Implicit calls to dunder methods never access instance members, so we pass @@ -4903,10 +4902,14 @@ impl<'db> Type<'db> { .place { Place::Type(dunder_callable, boundness) => { - let bindings = dunder_callable - .bindings(db) - .match_parameters(db, argument_types) - .check_types(db, argument_types, &TypeContext::default())?; + let (bindings, argument_types) = + dunder_callable.bindings(db).match_parameters(db, arguments); + let bindings = bindings.check_types( + db, + arguments, + &argument_types, + &TypeContext::default(), + )?; if boundness == Boundness::PossiblyUnbound { return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); } @@ -5389,8 +5392,7 @@ impl<'db> Type<'db> { let new_call_outcome = new_method.and_then(|new_method| { match new_method.place.try_call_dunder_get(db, self_type) { Place::Type(new_method, boundness) => { - let result = - new_method.try_call(db, argument_types.with_self(Some(self_type)).as_ref()); + let result = new_method.try_call(db, &argument_types.with_self(self_type)); if boundness == Boundness::PossiblyUnbound { Some(Err(DunderNewCallError::PossiblyUnbound(result.err()))) } else { diff --git a/crates/ty_python_semantic/src/types/call.rs b/crates/ty_python_semantic/src/types/call.rs index 8c00ab3479..7390d9da13 100644 --- a/crates/ty_python_semantic/src/types/call.rs +++ b/crates/ty_python_semantic/src/types/call.rs @@ -6,7 +6,7 @@ use crate::types::call::bind::BindingError; mod arguments; pub(crate) mod bind; -pub(super) use arguments::{Argument, CallArguments}; +pub(super) use arguments::{Argument, CallArgumentTypes, CallArguments}; pub(super) use bind::{Binding, Bindings, CallableBinding, MatchedArgument}; /// Wraps a [`Bindings`] for an unsuccessful call with information about why the call was diff --git a/crates/ty_python_semantic/src/types/call/arguments.rs b/crates/ty_python_semantic/src/types/call/arguments.rs index fc8bf871e5..32625b2180 100644 --- a/crates/ty_python_semantic/src/types/call/arguments.rs +++ b/crates/ty_python_semantic/src/types/call/arguments.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use itertools::{Either, Itertools}; use ruff_python_ast as ast; @@ -90,21 +88,16 @@ impl<'a, 'db> CallArguments<'a, 'db> { self.types.iter().map(|ty| ty.unwrap_or_else(Type::unknown)) } - /// Prepend an optional extra synthetic argument (for a `self` or `cls` parameter) to the front - /// of this argument list. (If `bound_self` is none, we return the argument list - /// unmodified.) - pub(crate) fn with_self(&self, bound_self: Option>) -> Cow<'_, Self> { - if bound_self.is_some() { - let arguments = std::iter::once(Argument::Synthetic) - .chain(self.arguments.iter().copied()) - .collect(); - let types = std::iter::once(bound_self) - .chain(self.types.iter().copied()) - .collect(); - Cow::Owned(CallArguments { arguments, types }) - } else { - Cow::Borrowed(self) - } + /// Prepend a extra synthetic argument (for a `self` or `cls` parameter) to the front + /// of this argument list.) + pub(crate) fn with_self(&self, bound_self: Type<'db>) -> Self { + let arguments = std::iter::once(Argument::Synthetic) + .chain(self.arguments.iter().copied()) + .collect(); + let types = std::iter::once(Some(bound_self)) + .chain(self.types.iter().copied()) + .collect(); + CallArguments { arguments, types } } pub(crate) fn iter(&self) -> impl Iterator, Option>)> + '_ { @@ -117,32 +110,89 @@ impl<'a, 'db> CallArguments<'a, 'db> { (self.arguments.iter().copied()).zip(self.types.iter_mut()) } + pub(crate) fn iter_with_types<'iter>( + &'iter self, + argument_types: &'iter CallArgumentTypes<'db>, + ) -> impl Iterator, Option>)> + 'iter { + (self.arguments.iter().copied()).zip(argument_types.iter().copied()) + } +} + +impl<'a, 'db> FromIterator<(Argument<'a>, Option>)> for CallArguments<'a, 'db> { + fn from_iter(iter: T) -> Self + where + T: IntoIterator, Option>)>, + { + let (arguments, types) = iter.into_iter().unzip(); + Self { arguments, types } + } +} + +// TODO: Consider removing this, and go back to just storing nested `CallArguments`. +#[derive(Clone, Debug)] +pub(crate) struct CallArgumentTypes<'db> { + types: Vec>>, +} + +impl<'db> CallArgumentTypes<'db> { + pub(crate) fn new(types: Vec>>) -> Self { + Self { types } + } + + pub(crate) fn types(&self) -> &[Option>] { + &self.types + } + + pub(crate) fn len(&self) -> usize { + self.types.len() + } + + pub(crate) fn iter(&self) -> impl Iterator>> { + self.types.iter() + } + + pub(crate) fn iter_mut(&mut self) -> impl Iterator>> { + self.types.iter_mut() + } + + pub(crate) fn with_self(&self, bound_self: Type<'db>) -> Self { + let types = std::iter::once(Some(bound_self)) + .chain(self.types.iter().copied()) + .collect(); + + CallArgumentTypes { types } + } +} + +impl<'db> CallArgumentTypes<'db> { /// Returns an iterator on performing [argument type expansion]. /// /// Each element of the iterator represents a set of argument lists, where each argument list /// contains the same arguments, but with one or more of the argument types expanded. /// + /// The iterator will return an element for every argument, even if it could not be expanded. + /// /// [argument type expansion]: https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion - pub(super) fn expand(&self, db: &'db dyn Db) -> impl Iterator> + '_ { + pub(super) fn expand(&self, db: &'db dyn Db) -> impl Iterator> + '_ { /// Maximum number of argument lists that can be generated in a single expansion step. static MAX_EXPANSIONS: usize = 512; /// Represents the state of the expansion process. - enum State<'a, 'b, 'db> { + enum State<'a, 'db> { LimitReached(usize), - Expanding(ExpandingState<'a, 'b, 'db>), + Expanding(ExpandingState<'a, 'db>), } /// Represents the expanding state with either the initial types or the expanded types. /// /// This is useful to avoid cloning the initial types vector if none of the types can be /// expanded. - enum ExpandingState<'a, 'b, 'db> { - Initial(&'b Vec>>), - Expanded(Vec>), + enum ExpandingState<'a, 'db> { + Initial(&'a [Option>]), + Expanded(Vec>), } - impl<'db> ExpandingState<'_, '_, 'db> { + impl<'db> ExpandingState<'_, 'db> { fn len(&self) -> usize { match self { ExpandingState::Initial(_) => 1, @@ -152,11 +202,9 @@ impl<'a, 'db> CallArguments<'a, 'db> { fn iter(&self) -> impl Iterator>]> + '_ { match self { - ExpandingState::Initial(types) => { - Either::Left(std::iter::once(types.as_slice())) - } + ExpandingState::Initial(types) => Either::Left(std::iter::once(*types)), ExpandingState::Expanded(expanded) => { - Either::Right(expanded.iter().map(CallArguments::types)) + Either::Right(expanded.iter().map(CallArgumentTypes::types)) } } } @@ -165,7 +213,7 @@ impl<'a, 'db> CallArguments<'a, 'db> { let mut index = 0; std::iter::successors( - Some(State::Expanding(ExpandingState::Initial(&self.types))), + Some(State::Expanding(ExpandingState::Initial(self.types()))), move |previous| { let state = match previous { State::LimitReached(index) => return Some(State::LimitReached(*index)), @@ -174,7 +222,7 @@ impl<'a, 'db> CallArguments<'a, 'db> { // Find the next type that can be expanded. let expanded_types = loop { - let arg_type = self.types.get(index)?; + let arg_type = self.types().get(index)?; if let Some(arg_type) = arg_type { if let Some(expanded_types) = expand_type(db, *arg_type) { break expanded_types; @@ -198,10 +246,7 @@ impl<'a, 'db> CallArguments<'a, 'db> { for subtype in &expanded_types { let mut new_expanded_types = pre_expanded_types.to_vec(); new_expanded_types[index] = Some(*subtype); - expanded_arguments.push(CallArguments::new( - self.arguments.clone(), - new_expanded_types, - )); + expanded_arguments.push(CallArgumentTypes::new(new_expanded_types)); } } @@ -226,8 +271,8 @@ impl<'a, 'db> CallArguments<'a, 'db> { /// Represents a single element of the expansion process for argument types for [`expand`]. /// -/// [`expand`]: CallArguments::expand -pub(super) enum Expansion<'a, 'db> { +/// [`expand`]: CallArgumentTypes::expand +pub(super) enum Expansion<'db> { /// Indicates that the expansion process has reached the maximum number of argument lists /// that can be generated in a single step. /// @@ -237,17 +282,7 @@ pub(super) enum Expansion<'a, 'db> { /// Contains the expanded argument lists, where each list contains the same arguments, but with /// one or more of the argument types expanded. - Expanded(Vec>), -} - -impl<'a, 'db> FromIterator<(Argument<'a>, Option>)> for CallArguments<'a, 'db> { - fn from_iter(iter: T) -> Self - where - T: IntoIterator, Option>)>, - { - let (arguments, types) = iter.into_iter().unzip(); - Self { arguments, types } - } + Expanded(Vec>), } /// Returns `true` if the type can be expanded into its subtypes. diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 7182694780..5e27ea6e4c 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -3,8 +3,9 @@ //! [signatures][crate::types::signatures], we have to handle the fact that the callable might be a //! union of types, each of which might contain multiple overloads. +use std::borrow::Cow; use std::collections::HashSet; -use std::fmt; +use std::{fmt, iter}; use itertools::{Either, Itertools}; use ruff_db::parsed::parsed_module; @@ -16,7 +17,7 @@ use crate::Program; use crate::db::Db; use crate::dunder_all::dunder_all_names; use crate::place::{Boundness, Place}; -use crate::types::call::arguments::{Expansion, is_expandable_type}; +use crate::types::call::arguments::{CallArgumentTypes, is_expandable_type}; use crate::types::diagnostic::{ CALL_NON_CALLABLE, CONFLICTING_ARGUMENT_FORMS, INVALID_ARGUMENT_TYPE, MISSING_ARGUMENT, NO_MATCHING_OVERLOAD, PARAMETER_ALREADY_ASSIGNED, POSITIONAL_ONLY_PARAMETER_AS_KWARG, @@ -38,6 +39,42 @@ use crate::types::{ use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; use ruff_python_ast::{self as ast, PythonVersion}; +#[derive(Clone, Debug)] +pub(crate) struct MatchedCallArguments<'db> { + bindings: SmallVec<[BindingCallArguments<'db>; 1]>, +} + +impl<'db> MatchedCallArguments<'db> { + pub(crate) fn bindings(&self) -> impl Iterator> { + self.bindings.iter() + } + + pub(crate) fn bindings_mut(&mut self) -> impl Iterator> { + self.bindings.iter_mut() + } +} + +#[derive(Clone, Debug)] +pub(crate) struct BindingCallArguments<'db> { + overloads: SmallVec<[CallArgumentTypes<'db>; 1]>, +} + +impl<'a, 'db> From<&CallArguments<'a, 'db>> for CallArgumentTypes<'db> { + fn from(arguments: &CallArguments<'a, 'db>) -> Self { + CallArgumentTypes::new(arguments.types().to_owned()) + } +} + +impl<'db> BindingCallArguments<'db> { + pub(crate) fn overloads(&self) -> impl Iterator> { + self.overloads.iter() + } + + pub(crate) fn overloads_mut(&mut self) -> impl Iterator> { + self.overloads.iter_mut() + } +} + /// Binding information for a possible union of callables. At a call site, the arguments must be /// compatible with _all_ of the types in the union for the call to be valid. /// @@ -102,18 +139,33 @@ impl<'db> Bindings<'db> { /// /// Once you have argument types available, you can call [`check_types`][Self::check_types] to /// verify that each argument type is assignable to the corresponding parameter type. - pub(crate) fn match_parameters( + pub(crate) fn match_parameters<'a>( mut self, db: &'db dyn Db, - arguments: &CallArguments<'_, 'db>, - ) -> Self { + arguments: &CallArguments<'a, 'db>, + ) -> (Self, MatchedCallArguments<'db>) { let mut argument_forms = ArgumentForms::new(arguments.len()); for binding in &mut self.elements { binding.match_parameters(db, arguments, &mut argument_forms); } argument_forms.shrink_to_fit(); self.argument_forms = argument_forms; - self + + let binding_arguments = self + .elements + .iter() + .map(|element| BindingCallArguments { + overloads: (0..element.overloads.len()) + .map(|_| CallArgumentTypes::from(arguments)) + .collect(), + }) + .collect(); + + let arguments = MatchedCallArguments { + bindings: binding_arguments, + }; + + (self, arguments) } /// Verify that the type of each argument is assignable to type of the parameter that it was @@ -131,12 +183,13 @@ impl<'db> Bindings<'db> { pub(crate) fn check_types( mut self, db: &'db dyn Db, - argument_types: &CallArguments<'_, 'db>, + arguments: &CallArguments<'_, 'db>, + argument_types: &MatchedCallArguments<'db>, call_expression_tcx: &TypeContext<'db>, ) -> Result> { - for element in &mut self.elements { + for (element, argument_types) in iter::zip(&mut self.elements, argument_types.bindings()) { if let Some(mut updated_argument_forms) = - element.check_types(db, argument_types, call_expression_tcx) + element.check_types(db, arguments, argument_types, call_expression_tcx) { // If this element returned a new set of argument forms (indicating successful // argument type expansion), update the `Bindings` with these forms. @@ -1276,7 +1329,10 @@ impl<'db> CallableBinding<'db> { ) { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. - let arguments = arguments.with_self(self.bound_type); + let arguments = match self.bound_type { + None => Cow::Borrowed(arguments), + Some(bound_self) => Cow::Owned(arguments.with_self(bound_self)), + }; for overload in &mut self.overloads { overload.match_parameters(db, arguments.as_ref(), argument_forms); @@ -1286,12 +1342,27 @@ impl<'db> CallableBinding<'db> { fn check_types( &mut self, db: &'db dyn Db, - argument_types: &CallArguments<'_, 'db>, + arguments: &CallArguments<'_, 'db>, + argument_types: &BindingCallArguments<'db>, call_expression_tcx: &TypeContext<'db>, ) -> Option { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. - let argument_types = argument_types.with_self(self.bound_type); + let (arguments, argument_types) = match self.bound_type { + None => (Cow::Borrowed(arguments), Cow::Borrowed(argument_types)), + Some(bound_self) => { + let arguments = arguments.with_self(bound_self); + let overloads = argument_types + .overloads() + .map(|argument_types| argument_types.with_self(bound_self)) + .collect(); + + ( + Cow::Owned(arguments), + Cow::Owned(BindingCallArguments { overloads }), + ) + } + }; // Step 1: Check the result of the arity check which is done by `match_parameters` let matching_overload_indexes = match self.matching_overload_index() { @@ -1300,15 +1371,26 @@ impl<'db> CallableBinding<'db> { // still perform type checking for non-overloaded function to provide better user // experience. if let [overload] = self.overloads.as_mut_slice() { - overload.check_types(db, argument_types.as_ref(), call_expression_tcx); + overload.check_types( + db, + arguments.as_ref(), + &argument_types.overloads[0], + call_expression_tcx, + ); } + return None; } MatchingOverloadIndex::Single(index) => { // If only one candidate overload remains, it is the winning match. Evaluate it as // a regular (non-overloaded) call. self.matching_overload_index = Some(index); - self.overloads[index].check_types(db, argument_types.as_ref(), call_expression_tcx); + self.overloads[index].check_types( + db, + arguments.as_ref(), + &argument_types.overloads[index], + call_expression_tcx, + ); return None; } MatchingOverloadIndex::Multiple(indexes) => { @@ -1319,8 +1401,13 @@ impl<'db> CallableBinding<'db> { // Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine // whether it is compatible with the supplied argument list. - for (_, overload) in self.matching_overloads_mut() { - overload.check_types(db, argument_types.as_ref(), call_expression_tcx); + for (overload_index, overload) in self.matching_overloads_mut() { + overload.check_types( + db, + arguments.as_ref(), + &argument_types.overloads[overload_index], + call_expression_tcx, + ); } match self.matching_overload_index() { @@ -1352,7 +1439,8 @@ impl<'db> CallableBinding<'db> { // If two or more candidate overloads remain, proceed to step 5. self.filter_overloads_using_any_or_unknown( db, - argument_types.as_ref(), + &arguments, + &argument_types, &indexes, ); } @@ -1365,29 +1453,45 @@ impl<'db> CallableBinding<'db> { // Step 3: Perform "argument type expansion". Reference: // https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion - let mut expansions = argument_types.expand(db).peekable(); + let overload_expansions = argument_types + .overloads() + .filter_map(|argument_types| { + let mut expanded = argument_types.expand(db).peekable(); + if expanded.peek().is_some() { + Some(expanded) + } else { + None + } + }) + .collect::>(); - // Return early if there are no argument types to expand. - expansions.peek()?; + // Return early if there are no argument types to expand for any overload. + if overload_expansions.is_empty() { + return None; + } // At this point, there's at least one argument that can be expanded. // // This heuristic tries to detect if there's any need to perform argument type expansion or // not by checking whether there are any non-expandable argument type that cannot be // assigned to any of the overloads. - for (argument_index, (argument, argument_type)) in argument_types.iter().enumerate() { + for (argument_index, (argument, _)) in arguments.iter().enumerate() { // TODO: Remove `Keywords` once `**kwargs` support is added if matches!(argument, Argument::Synthetic | Argument::Keywords) { continue; } - let Some(argument_type) = argument_type else { - continue; - }; - if is_expandable_type(db, argument_type) { - continue; - } let mut is_argument_assignable_to_any_overload = false; - 'overload: for overload in &self.overloads { + 'overload: for (overload_index, overload) in self.overloads.iter().enumerate() { + let argument_type = + argument_types.overloads[overload_index].types()[argument_index]; + + let Some(argument_type) = argument_type else { + continue; + }; + if is_expandable_type(db, argument_type) { + continue; + } + for parameter_index in &overload.argument_matches[argument_index].parameters { let parameter_type = overload.signature.parameters()[*parameter_index] .annotated_type() @@ -1400,9 +1504,8 @@ impl<'db> CallableBinding<'db> { } if !is_argument_assignable_to_any_overload { tracing::debug!( - "Argument at {argument_index} (`{}`) is not assignable to any of the \ - remaining overloads, skipping argument type expansion", - argument_type.display(db) + "Argument at {argument_index} is not assignable to any of the \ + remaining overloads, skipping argument type expansion" ); return None; } @@ -1414,7 +1517,7 @@ impl<'db> CallableBinding<'db> { // the non-expanded argument types. let post_evaluation_snapshot = snapshotter.take(self); - for expansion in expansions { + for expansion in overload_expansions { let expanded_argument_lists = match expansion { Expansion::LimitReached(index) => { snapshotter.restore(self, post_evaluation_snapshot); @@ -1586,6 +1689,7 @@ impl<'db> CallableBinding<'db> { &mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>, + _argument_types: &BindingCallArguments<'db>, matching_overload_indexes: &[usize], ) { // These are the parameter indexes that matches the arguments that participate in the @@ -1599,7 +1703,7 @@ impl<'db> CallableBinding<'db> { // participating parameter indexes. let mut top_materialized_argument_types = vec![]; - for (argument_index, argument_type) in arguments.iter_types().enumerate() { + for (argument_index, (_, argument_type)) in arguments.iter().enumerate() { let mut first_parameter_type: Option> = None; let mut participating_parameter_index = None; @@ -1624,7 +1728,11 @@ impl<'db> CallableBinding<'db> { if let Some(parameter_index) = participating_parameter_index { participating_parameter_indexes.insert(parameter_index); - top_materialized_argument_types.push(argument_type.top_materialization(db)); + top_materialized_argument_types.push( + argument_type + .unwrap_or_else(Type::unknown) + .top_materialization(db), + ); } } @@ -1750,6 +1858,12 @@ impl<'db> CallableBinding<'db> { } } + /// Returns an iterator over the overloads for this call binding, including + /// those that did not match. + pub(crate) fn overloads(&self) -> impl Iterator> { + self.overloads.iter() + } + /// Returns an iterator over all the overloads that matched for this call binding. pub(crate) fn matching_overloads(&self) -> impl Iterator)> { self.overloads @@ -2378,6 +2492,7 @@ struct ArgumentTypeChecker<'a, 'db> { db: &'db dyn Db, signature: &'a Signature<'db>, arguments: &'a CallArguments<'a, 'db>, + argument_types: &'a CallArgumentTypes<'db>, argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], call_expression_tcx: &'a TypeContext<'db>, @@ -2392,6 +2507,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { db: &'db dyn Db, signature: &'a Signature<'db>, arguments: &'a CallArguments<'a, 'db>, + argument_types: &'a CallArgumentTypes<'db>, argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], call_expression_tcx: &'a TypeContext<'db>, @@ -2401,6 +2517,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { db, signature, arguments, + argument_types, argument_matches, parameter_tys, call_expression_tcx, @@ -2413,7 +2530,10 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { fn enumerate_argument_types( &self, ) -> impl Iterator, Argument<'a>, Type<'db>)> + 'a { - let mut iter = self.arguments.iter().enumerate(); + let mut iter = self + .arguments + .iter_with_types(self.argument_types) + .enumerate(); let mut num_synthetic_args = 0; std::iter::from_fn(move || { let (argument_index, (argument, argument_type)) = iter.next()?; @@ -2829,12 +2949,14 @@ impl<'db> Binding<'db> { &mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>, + argument_types: &CallArgumentTypes<'db>, call_expression_tcx: &TypeContext<'db>, ) { let mut checker = ArgumentTypeChecker::new( db, &self.signature, arguments, + argument_types, &self.argument_matches, &mut self.parameter_tys, call_expression_tcx, @@ -2877,11 +2999,12 @@ impl<'db> Binding<'db> { pub(crate) fn arguments_for_parameter<'a>( &'a self, - argument_types: &'a CallArguments<'a, 'db>, + arguments: &'a CallArguments<'a, 'db>, + argument_types: &'a CallArgumentTypes<'db>, parameter_index: usize, ) -> impl Iterator, Type<'db>)> + 'a { - argument_types - .iter() + arguments + .iter_with_types(argument_types) .zip(&self.argument_matches) .filter(move |(_, argument_matches)| { argument_matches.parameters.contains(¶meter_index) diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 0afb14480a..d17d59357a 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -14,6 +14,7 @@ use crate::semantic_index::symbol::Symbol; use crate::semantic_index::{ DeclarationWithConstraint, SemanticIndex, attribute_declarations, attribute_scopes, }; +use crate::types::call::CallArgumentTypes; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::context::InferContext; use crate::types::diagnostic::{INVALID_LEGACY_TYPE_VARIABLE, INVALID_TYPE_ALIAS_TYPE}; @@ -4909,6 +4910,7 @@ impl KnownClass { index: &SemanticIndex<'db>, overload: &mut Binding<'db>, call_arguments: &CallArguments<'_, 'db>, + call_argument_types: &CallArgumentTypes<'db>, call_expression: &ast::ExprCall, ) { let db = context.db(); @@ -5127,7 +5129,7 @@ impl KnownClass { let elements = UnionType::new( db, overload - .arguments_for_parameter(call_arguments, 1) + .arguments_for_parameter(call_arguments, call_argument_types, 1) .map(|(_, ty)| ty) .collect::>(), ); diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index ed10c7b819..0ca3eeca7c 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -64,7 +64,7 @@ use crate::semantic_index::ast_ids::HasScopedUseId; use crate::semantic_index::definition::Definition; use crate::semantic_index::scope::ScopeId; use crate::semantic_index::{FileScopeId, SemanticIndex, semantic_index}; -use crate::types::call::{Binding, CallArguments}; +use crate::types::call::{Binding, CallArgumentTypes, CallArguments}; use crate::types::constraints::ConstraintSet; use crate::types::context::InferContext; use crate::types::diagnostic::{ @@ -1416,6 +1416,7 @@ impl KnownFunction { context: &InferContext<'db, '_>, overload: &mut Binding<'db>, call_arguments: &CallArguments<'_, 'db>, + call_argument_types: &CallArgumentTypes<'db>, call_expression: &ast::ExprCall, file: File, ) { @@ -1425,7 +1426,7 @@ impl KnownFunction { match self { KnownFunction::RevealType => { let revealed_type = overload - .arguments_for_parameter(call_arguments, 0) + .arguments_for_parameter(call_arguments, call_argument_types, 0) .fold(UnionBuilder::new(db), |builder, (_, ty)| builder.add(ty)) .build(); if let Some(builder) = diff --git a/crates/ty_python_semantic/src/types/ide_support.rs b/crates/ty_python_semantic/src/types/ide_support.rs index 1b00a41007..0fdf263242 100644 --- a/crates/ty_python_semantic/src/types/ide_support.rs +++ b/crates/ty_python_semantic/src/types/ide_support.rs @@ -877,7 +877,7 @@ pub fn call_signature_details<'db>( CallArguments::from_arguments(&call_expr.arguments, |_, splatted_value| { splatted_value.inferred_type(model) }); - let bindings = callable_type + let (bindings, _) = callable_type .bindings(db) .match_parameters(db, &call_arguments); diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 298c3b4644..740597edc7 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -44,6 +44,7 @@ use crate::semantic_index::symbol::{ScopedSymbolId, Symbol}; use crate::semantic_index::{ ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table, }; +use crate::types::call::bind::MatchedCallArguments; use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind}; use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator}; use crate::types::context::{InNoTypeCheck, InferContext}; @@ -4917,6 +4918,51 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.infer_expression(expression, TypeContext::default()) } + fn infer_matched_argument_types<'a>( + &mut self, + ast_arguments: &ast::Arguments, + argument_types: &mut MatchedCallArguments<'db>, + bindings: &Bindings<'db>, + ) { + for (binding, argument_types) in iter::zip(bindings, argument_types.bindings_mut()) { + for (overload, argument_types) in + iter::zip(binding.overloads(), argument_types.overloads_mut()) + { + debug_assert_eq!(ast_arguments.len(), argument_types.len()); + debug_assert_eq!(argument_types.len(), bindings.argument_forms().len()); + + for (argument_index, argument_type, argument_form, ast_argument) in itertools::izip!( + 0.., + argument_types.iter_mut(), + bindings.argument_forms(), + ast_arguments.arguments_source_order() + ) { + let ast_argument = match ast_argument { + // We already inferred the type of splatted arguments. + ast::ArgOrKeyword::Arg(ast::Expr::Starred(_)) + | ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue, + ast::ArgOrKeyword::Arg(arg) => arg, + ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value, + }; + + for (parameter_index, variadic_argument_type) in + overload.argument_matches()[argument_index].iter() + { + if variadic_argument_type.is_some() { + continue; + } + + let parameter = &overload.signature.parameters()[parameter_index]; + let tcx = TypeContext::new(parameter.annotated_type()); + + *argument_type = + Some(self.infer_argument_type(ast_argument, *argument_form, tcx)); + } + } + } + } + } + fn infer_argument_types<'a>( &mut self, ast_arguments: &ast::Arguments, @@ -5082,7 +5128,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { return; } let previous = self.expressions.insert(expression.into(), ty); - assert_eq!(previous, None); + assert!(previous == None || previous == Some(ty)); } fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type<'db> { @@ -5977,10 +6023,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let bindings = callable_type + let (bindings, mut call_argument_types) = callable_type .bindings(self.db()) .match_parameters(self.db(), &call_arguments); - self.infer_argument_types(arguments, &mut call_arguments, bindings.argument_forms()); + + self.infer_matched_argument_types(arguments, &mut call_argument_types, &bindings); // Validate `TypedDict` constructor calls after argument type inference if let Some(class_literal) = callable_type.into_class_literal() { @@ -5998,17 +6045,23 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let mut bindings = match bindings.check_types(self.db(), &call_arguments, &tcx) { - Ok(bindings) => bindings, - Err(CallError(_, bindings)) => { - bindings.report_diagnostics(&self.context, call_expression.into()); - return bindings.return_type(self.db()); - } - }; + let mut bindings = + match bindings.check_types(self.db(), &call_arguments, &call_argument_types, &tcx) { + Ok(bindings) => bindings, + Err(CallError(_, bindings)) => { + bindings.report_diagnostics(&self.context, call_expression.into()); + return bindings.return_type(self.db()); + } + }; - for binding in &mut bindings { + for (binding, call_argument_types) in + iter::zip(&mut bindings, call_argument_types.bindings()) + { let binding_type = binding.callable_type; - for (_, overload) in binding.matching_overloads_mut() { + for ((_, overload), call_argument_types) in binding + .matching_overloads_mut() + .zip(call_argument_types.overloads()) + { match binding_type { Type::FunctionLiteral(function_literal) => { if let Some(known_function) = function_literal.known(self.db()) { @@ -6016,6 +6069,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &self.context, overload, &call_arguments, + &call_argument_types, call_expression, self.file(), ); @@ -6028,6 +6082,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.index, overload, &call_arguments, + &call_argument_types, call_expression, ); } @@ -8549,7 +8604,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { specialize: impl FnOnce(&[Option>]) -> Type<'db>, ) -> Type<'db> { let slice_node = subscript.slice.as_ref(); - let call_argument_types = match slice_node { + let call_arguments = match slice_node { ast::Expr::Tuple(tuple) => { let arguments = CallArguments::positional( tuple.elts.iter().map(|elt| self.infer_type_expression(elt)), @@ -8563,10 +8618,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { _ => CallArguments::positional([self.infer_type_expression(slice_node)]), }; let binding = Binding::single(value_ty, generic_context.signature(self.db())); - let bindings = match Bindings::from(binding) - .match_parameters(self.db(), &call_argument_types) - .check_types(self.db(), &call_argument_types, &TypeContext::default()) - { + // TODO: Call directly on `Binding`, not `Bindings`, to avoid the indirection. + let (bindings, call_argument_types) = + Bindings::from(binding).match_parameters(self.db(), &call_arguments); + let bindings = match bindings.check_types( + self.db(), + &call_arguments, + &call_argument_types, + &TypeContext::default(), + ) { Ok(bindings) => bindings, Err(CallError(_, bindings)) => { bindings.report_diagnostics(&self.context, subscript.into());