diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index ca34c94b5e..a59de5d27b 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -290,16 +290,10 @@ from overloaded import A, f def _(x: int, y: A | int): reveal_type(f(x)) # revealed: int - # TODO: revealed: int - # TODO: no error - # error: [no-matching-overload] - reveal_type(f(*(x,))) # revealed: Unknown + reveal_type(f(*(x,))) # revealed: int reveal_type(f(y)) # revealed: A | int - # TODO: revealed: A | int - # TODO: no error - # error: [no-matching-overload] - reveal_type(f(*(y,))) # revealed: Unknown + reveal_type(f(*(y,))) # revealed: A | int ``` ### Generics (PEP 695) @@ -328,16 +322,10 @@ from overloaded import B, f def _(x: int, y: B | int): reveal_type(f(x)) # revealed: int - # TODO: revealed: int - # TODO: no error - # error: [no-matching-overload] - reveal_type(f(*(x,))) # revealed: Unknown + reveal_type(f(*(x,))) # revealed: int reveal_type(f(y)) # revealed: B | int - # TODO: revealed: B | int - # TODO: no error - # error: [no-matching-overload] - reveal_type(f(*(y,))) # revealed: Unknown + reveal_type(f(*(y,))) # revealed: B | int ``` ### Expanding `bool` @@ -1236,21 +1224,14 @@ def _(integer: int, string: str, any: Any, list_any: list[Any]): reveal_type(f(*(integer, string))) # revealed: int reveal_type(f(string, integer)) # revealed: int - # TODO: revealed: int - # TODO: no error - # error: [no-matching-overload] - reveal_type(f(*(string, integer))) # revealed: Unknown + reveal_type(f(*(string, integer))) # revealed: int # This matches the second overload and is _not_ the case of ambiguous overload matching. reveal_type(f(string, any)) # revealed: Any - # TODO: Any - reveal_type(f(*(string, any))) # revealed: tuple[str, Any] + reveal_type(f(*(string, any))) # revealed: Any reveal_type(f(string, list_any)) # revealed: list[Any] - # TODO: revealed: list[Any] - # TODO: no error - # error: [no-matching-overload] - reveal_type(f(*(string, list_any))) # revealed: Unknown + reveal_type(f(*(string, list_any))) # revealed: list[Any] ``` ### Generic `self` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 5a910fc1c3..0b90b9a31c 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4843,7 +4843,7 @@ impl<'db> Type<'db> { argument_types: &CallArguments<'_, 'db>, ) -> Result, CallError<'db>> { self.bindings(db) - .match_parameters(argument_types) + .match_parameters(db, argument_types) .check_types(db, argument_types) } @@ -4892,7 +4892,7 @@ impl<'db> Type<'db> { Place::Type(dunder_callable, boundness) => { let bindings = dunder_callable .bindings(db) - .match_parameters(argument_types) + .match_parameters(db, argument_types) .check_types(db, argument_types)?; if boundness == Boundness::PossiblyUnbound { return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 7627602fa6..43c1aa90f2 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -7,7 +7,7 @@ use std::borrow::Cow; use std::collections::HashSet; use std::fmt; -use itertools::Itertools; +use itertools::{Either, Itertools}; use ruff_db::parsed::parsed_module; use smallvec::{SmallVec, smallvec, smallvec_inline}; @@ -101,11 +101,15 @@ impl<'db> Bindings<'db> { /// /// Once you have argument types available, you can call [`check_types`][Self::check_types] to /// verify that each argument type is assignable to the corresponding parameter type. - pub(crate) fn match_parameters(mut self, arguments: &CallArguments<'_, 'db>) -> Self { + pub(crate) fn match_parameters( + mut self, + db: &'db dyn Db, + arguments: &CallArguments<'_, 'db>, + ) -> Self { let mut argument_forms = vec![None; arguments.len()]; let mut conflicting_forms = vec![false; arguments.len()]; for binding in &mut self.elements { - binding.match_parameters(arguments, &mut argument_forms, &mut conflicting_forms); + binding.match_parameters(db, arguments, &mut argument_forms, &mut conflicting_forms); } self.argument_forms = argument_forms.into(); self.conflicting_forms = conflicting_forms.into(); @@ -1243,6 +1247,7 @@ impl<'db> CallableBinding<'db> { fn match_parameters( &mut self, + db: &'db dyn Db, arguments: &CallArguments<'_, 'db>, argument_forms: &mut [Option], conflicting_forms: &mut [bool], @@ -1252,7 +1257,7 @@ impl<'db> CallableBinding<'db> { let arguments = arguments.with_self(self.bound_type); for overload in &mut self.overloads { - overload.match_parameters(arguments.as_ref(), argument_forms, conflicting_forms); + overload.match_parameters(db, arguments.as_ref(), argument_forms, conflicting_forms); } } @@ -1903,7 +1908,7 @@ struct ArgumentMatcher<'a, 'db> { conflicting_forms: &'a mut [bool], errors: &'a mut Vec>, - argument_matches: Vec, + argument_matches: Vec>, parameter_matched: Vec, next_positional: usize, first_excess_positional: Option, @@ -1947,6 +1952,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { &mut self, argument_index: usize, argument: Argument<'a>, + argument_type: Option>, parameter_index: usize, parameter: &Parameter<'db>, positional: bool, @@ -1970,6 +1976,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { } let matched_argument = &mut self.argument_matches[argument_index]; matched_argument.parameters.push(parameter_index); + matched_argument.types.push(argument_type); matched_argument.matched = true; self.parameter_matched[parameter_index] = true; } @@ -1978,6 +1985,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { &mut self, argument_index: usize, argument: Argument<'a>, + argument_type: Option>, ) -> Result<(), ()> { if matches!(argument, Argument::Synthetic) { self.num_synthetic_args += 1; @@ -1996,6 +2004,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { self.assign_argument( argument_index, argument, + argument_type, parameter_index, parameter, !parameter.is_variadic(), @@ -2020,20 +2029,35 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { }); return Err(()); }; - self.assign_argument(argument_index, argument, parameter_index, parameter, false); + self.assign_argument( + argument_index, + argument, + None, + parameter_index, + parameter, + false, + ); Ok(()) } fn match_variadic( &mut self, + db: &'db dyn Db, argument_index: usize, argument: Argument<'a>, + argument_type: Option>, length: TupleLength, ) -> Result<(), ()> { + let tuple = argument_type.map(|ty| ty.iterate(db)); + let mut argument_types = match tuple.as_ref() { + Some(tuple) => Either::Left(tuple.all_elements().copied()), + None => Either::Right(std::iter::empty()), + }; + // We must be able to match up the fixed-length portion of the argument with positional // parameters, so we pass on any errors that occur. for _ in 0..length.minimum() { - self.match_positional(argument_index, argument)?; + self.match_positional(argument_index, argument, argument_types.next())?; } // If the tuple is variable-length, we assume that it will soak up all remaining positional @@ -2044,14 +2068,14 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { .get_positional(self.next_positional) .is_some() { - self.match_positional(argument_index, argument)?; + self.match_positional(argument_index, argument, argument_types.next())?; } } Ok(()) } - fn finish(self) -> Box<[MatchedArgument]> { + fn finish(self) -> Box<[MatchedArgument<'db>]> { if let Some(first_excess_argument_index) = self.first_excess_positional { self.errors.push(BindingError::TooManyPositionalArguments { first_excess_argument_index: self.get_argument_index(first_excess_argument_index), @@ -2088,7 +2112,7 @@ struct ArgumentTypeChecker<'a, 'db> { db: &'db dyn Db, signature: &'a Signature<'db>, arguments: &'a CallArguments<'a, 'db>, - argument_matches: &'a [MatchedArgument], + argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], errors: &'a mut Vec>, @@ -2101,7 +2125,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { db: &'db dyn Db, signature: &'a Signature<'db>, arguments: &'a CallArguments<'a, 'db>, - argument_matches: &'a [MatchedArgument], + argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], errors: &'a mut Vec>, ) -> Self { @@ -2156,12 +2180,17 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { for (argument_index, adjusted_argument_index, _, argument_type) in self.enumerate_argument_types() { - for parameter_index in &self.argument_matches[argument_index].parameters { - let parameter = ¶meters[*parameter_index]; + for (parameter_index, variadic_argument_type) in + self.argument_matches[argument_index].iter() + { + let parameter = ¶meters[parameter_index]; let Some(expected_type) = parameter.annotated_type() else { continue; }; - if let Err(error) = builder.infer(expected_type, argument_type) { + if let Err(error) = builder.infer( + expected_type, + variadic_argument_type.unwrap_or(argument_type), + ) { self.errors.push(BindingError::SpecializationError { error, argument_index: adjusted_argument_index, @@ -2305,7 +2334,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { /// Information about which parameter(s) an argument was matched against. This is tracked /// separately for each overload. #[derive(Clone, Debug, Default)] -pub struct MatchedArgument { +pub struct MatchedArgument<'db> { /// The index of the parameter(s) that an argument was matched against. A splatted argument /// might be matched against multiple parameters. pub parameters: SmallVec<[usize; 1]>, @@ -2314,6 +2343,33 @@ pub struct MatchedArgument { /// elements must have been successfully matched. (That means that this can be `false` while /// the `parameters` field is non-empty.) pub matched: bool, + + /// The types of a variadic argument when it's unpacked. + /// + /// The length of this vector is always the same as the `parameters` vector i.e., these are the + /// types assigned to each matched parameter. This isn't necessarily the same as the number of + /// types in the argument type which might not be a fixed-length iterable. + /// + /// Another thing to note is that the way this is populated means that for any other argument + /// kind (synthetic, positional, keyword, keyword-variadic), this will be a single-element + /// vector containing `None`, since we don't know the type of the argument when this is + /// constructed. So, this field is populated only for variadic arguments. + /// + /// For example, given a `*args` whose type is `tuple[A, B, C]` and the following parameters: + /// - `(x, *args)`: the `types` field will only have two elements (`B`, `C`) since `A` has been + /// matched with `x`. + /// - `(*args)`: the `types` field will have all the three elements (`A`, `B`, `C`) + types: SmallVec<[Option>; 1]>, +} + +impl<'db> MatchedArgument<'db> { + /// Returns an iterator over the parameter indices and the corresponding argument type. + pub fn iter(&self) -> impl Iterator>)> + '_ { + self.parameters + .iter() + .copied() + .zip(self.types.iter().copied()) + } } /// Binding information for one of the overloads of a callable. @@ -2341,7 +2397,7 @@ pub(crate) struct Binding<'db> { /// Information about which parameter(s) each argument was matched with, in argument source /// order. - argument_matches: Box<[MatchedArgument]>, + argument_matches: Box<[MatchedArgument<'db>]>, /// Bound types for parameters, in parameter source order, or `None` if no argument was matched /// to that parameter. @@ -2374,6 +2430,7 @@ impl<'db> Binding<'db> { pub(crate) fn match_parameters( &mut self, + db: &'db dyn Db, arguments: &CallArguments<'_, 'db>, argument_forms: &mut [Option], conflicting_forms: &mut [bool], @@ -2386,16 +2443,17 @@ impl<'db> Binding<'db> { conflicting_forms, &mut self.errors, ); - for (argument_index, (argument, _)) in arguments.iter().enumerate() { + for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() { match argument { Argument::Positional | Argument::Synthetic => { - let _ = matcher.match_positional(argument_index, argument); + let _ = matcher.match_positional(argument_index, argument, None); } Argument::Keyword(name) => { let _ = matcher.match_keyword(argument_index, argument, name); } Argument::Variadic(length) => { - let _ = matcher.match_variadic(argument_index, argument, length); + let _ = + matcher.match_variadic(db, argument_index, argument, argument_type, length); } Argument::Keywords => { // TODO @@ -2532,7 +2590,7 @@ impl<'db> Binding<'db> { /// Returns a vector where each index corresponds to an argument position, /// and the value is the parameter index that argument maps to (if any). - pub(crate) fn argument_matches(&self) -> &[MatchedArgument] { + pub(crate) fn argument_matches(&self) -> &[MatchedArgument<'db>] { &self.argument_matches } } @@ -2542,7 +2600,7 @@ struct BindingSnapshot<'db> { return_ty: Type<'db>, specialization: Option>, inherited_specialization: Option>, - argument_matches: Box<[MatchedArgument]>, + argument_matches: Box<[MatchedArgument<'db>]>, parameter_tys: Box<[Option>]>, errors: Vec>, } diff --git a/crates/ty_python_semantic/src/types/ide_support.rs b/crates/ty_python_semantic/src/types/ide_support.rs index f96401bd2f..328f5d20d3 100644 --- a/crates/ty_python_semantic/src/types/ide_support.rs +++ b/crates/ty_python_semantic/src/types/ide_support.rs @@ -801,7 +801,7 @@ pub struct CallSignatureDetails<'db> { /// Mapping from argument indices to parameter indices. This helps /// determine which parameter corresponds to which argument position. - pub argument_to_parameter_mapping: Vec, + pub argument_to_parameter_mapping: Vec>, } /// Extract signature details from a function call expression. @@ -821,7 +821,9 @@ pub fn call_signature_details<'db>( CallArguments::from_arguments(db, &call_expr.arguments, |_, splatted_value| { splatted_value.inferred_type(model) }); - let bindings = callable_type.bindings(db).match_parameters(&call_arguments); + let bindings = callable_type + .bindings(db) + .match_parameters(db, &call_arguments); // Extract signature details from all callable bindings bindings diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index f190205d6f..2fa57298bd 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -6360,7 +6360,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let bindings = callable_type .bindings(self.db()) - .match_parameters(&call_arguments); + .match_parameters(self.db(), &call_arguments); self.infer_argument_types(arguments, &mut call_arguments, &bindings.argument_forms); // Validate `TypedDict` constructor calls after argument type inference @@ -8812,7 +8812,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; let binding = Binding::single(value_ty, generic_context.signature(self.db())); let bindings = match Bindings::from(binding) - .match_parameters(&call_argument_types) + .match_parameters(self.db(), &call_argument_types) .check_types(self.db(), &call_argument_types) { Ok(bindings) => bindings,