diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 2bb6a955eb..e8fce04ec0 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -35,7 +35,7 @@ use crate::semantic_index::definition::Definition; use crate::semantic_index::place::{ScopeId, ScopedPlaceId}; use crate::semantic_index::{imported_modules, place_table, semantic_index}; use crate::suppression::check_suppressions; -use crate::types::call::{Binding, Bindings, CallArgumentTypes, CallableBinding}; +use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding}; pub(crate) use crate::types::class_base::ClassBase; use crate::types::context::{LintDiagnosticGuard, LintDiagnosticGuardBuilder}; use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION}; @@ -2683,7 +2683,7 @@ impl<'db> Type<'db> { if let Place::Type(descr_get, descr_get_boundness) = descr_get { let return_ty = descr_get - .try_call(db, &CallArgumentTypes::positional([self, instance, owner])) + .try_call(db, &CallArguments::positional([self, instance, owner])) .map(|bindings| { if descr_get_boundness == Boundness::Bound { bindings.return_type(db) @@ -3134,7 +3134,7 @@ impl<'db> Type<'db> { self.try_call_dunder( db, "__getattr__", - CallArgumentTypes::positional([Type::string_literal(db, &name)]), + CallArguments::positional([Type::string_literal(db, &name)]), ) .map(|outcome| Place::bound(outcome.return_type(db))) // TODO: Handle call errors here. @@ -3153,7 +3153,7 @@ impl<'db> Type<'db> { self.try_call_dunder_with_policy( db, "__getattribute__", - &mut CallArgumentTypes::positional([Type::string_literal(db, &name)]), + &mut CallArguments::positional([Type::string_literal(db, &name)]), MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, ) .map(|outcome| Place::bound(outcome.return_type(db))) @@ -3275,7 +3275,7 @@ impl<'db> Type<'db> { // runtime there is a fallback to `__len__`, since `__bool__` takes precedence // and a subclass could add a `__bool__` method. - match self.try_call_dunder(db, "__bool__", CallArgumentTypes::none()) { + match self.try_call_dunder(db, "__bool__", CallArguments::none()) { Ok(outcome) => { let return_type = outcome.return_type(db); if !return_type.is_assignable_to(db, KnownClass::Bool.to_instance(db)) { @@ -3474,7 +3474,7 @@ impl<'db> Type<'db> { return usize_len.try_into().ok().map(Type::IntLiteral); } - let return_ty = match self.try_call_dunder(db, "__len__", CallArgumentTypes::none()) { + let return_ty = match self.try_call_dunder(db, "__len__", CallArguments::none()) { Ok(bindings) => bindings.return_type(db), Err(CallDunderError::PossiblyUnbound(bindings)) => bindings.return_type(db), @@ -4394,7 +4394,7 @@ impl<'db> Type<'db> { fn try_call( self, db: &'db dyn Db, - argument_types: &CallArgumentTypes<'_, 'db>, + argument_types: &CallArguments<'_, 'db>, ) -> Result, CallError<'db>> { self.bindings(db) .match_parameters(argument_types) @@ -4409,7 +4409,7 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, name: &str, - mut argument_types: CallArgumentTypes<'_, 'db>, + mut argument_types: CallArguments<'_, 'db>, ) -> Result, CallDunderError<'db>> { self.try_call_dunder_with_policy( db, @@ -4430,7 +4430,7 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, name: &str, - argument_types: &mut CallArgumentTypes<'_, 'db>, + argument_types: &mut CallArguments<'_, 'db>, policy: MemberLookupPolicy, ) -> Result, CallDunderError<'db>> { // Implicit calls to dunder methods never access instance members, so we pass @@ -4492,19 +4492,19 @@ impl<'db> Type<'db> { self.try_call_dunder( db, "__getitem__", - CallArgumentTypes::positional([KnownClass::Int.to_instance(db)]), + CallArguments::positional([KnownClass::Int.to_instance(db)]), ) .map(|dunder_getitem_outcome| dunder_getitem_outcome.return_type(db)) }; let try_call_dunder_next_on_iterator = |iterator: Type<'db>| { iterator - .try_call_dunder(db, "__next__", CallArgumentTypes::none()) + .try_call_dunder(db, "__next__", CallArguments::none()) .map(|dunder_next_outcome| dunder_next_outcome.return_type(db)) }; let dunder_iter_result = self - .try_call_dunder(db, "__iter__", CallArgumentTypes::none()) + .try_call_dunder(db, "__iter__", CallArguments::none()) .map(|dunder_iter_outcome| dunder_iter_outcome.return_type(db)); match dunder_iter_result { @@ -4588,11 +4588,11 @@ impl<'db> Type<'db> { /// pass /// ``` fn try_enter(self, db: &'db dyn Db) -> Result, ContextManagerError<'db>> { - let enter = self.try_call_dunder(db, "__enter__", CallArgumentTypes::none()); + let enter = self.try_call_dunder(db, "__enter__", CallArguments::none()); let exit = self.try_call_dunder( db, "__exit__", - CallArgumentTypes::positional([Type::none(db), Type::none(db), Type::none(db)]), + CallArguments::positional([Type::none(db), Type::none(db), Type::none(db)]), ); // TODO: Make use of Protocols when we support it (the manager be assignable to `contextlib.AbstractContextManager`). @@ -4627,7 +4627,7 @@ impl<'db> Type<'db> { fn try_call_constructor( self, db: &'db dyn Db, - argument_types: CallArgumentTypes<'_, 'db>, + argument_types: CallArguments<'_, 'db>, ) -> Result, ConstructorCallError<'db>> { debug_assert!(matches!( self, @@ -6428,11 +6428,11 @@ impl<'db> ContextManagerError<'db> { Ok(_) | Err(CallDunderError::CallError(..)), Ok(_) | Err(CallDunderError::CallError(..)), ) = ( - context_expression_type.try_call_dunder(db, "__aenter__", CallArgumentTypes::none()), + context_expression_type.try_call_dunder(db, "__aenter__", CallArguments::none()), context_expression_type.try_call_dunder( db, "__aexit__", - CallArgumentTypes::positional([Type::unknown(), Type::unknown(), Type::unknown()]), + CallArguments::positional([Type::unknown(), Type::unknown(), Type::unknown()]), ), ) { diag.info(format_args!( @@ -6495,7 +6495,7 @@ impl<'db> IterationError<'db> { Self::IterCallError(_, dunder_iter_bindings) => dunder_iter_bindings .return_type(db) - .try_call_dunder(db, "__next__", CallArgumentTypes::none()) + .try_call_dunder(db, "__next__", CallArguments::none()) .map(|dunder_next_outcome| Some(dunder_next_outcome.return_type(db))) .unwrap_or_else(|dunder_next_call_error| dunder_next_call_error.return_type(db)), diff --git a/crates/ty_python_semantic/src/types/call.rs b/crates/ty_python_semantic/src/types/call.rs index f77cc429a5..c34489bbbc 100644 --- a/crates/ty_python_semantic/src/types/call.rs +++ b/crates/ty_python_semantic/src/types/call.rs @@ -4,7 +4,7 @@ use crate::Db; mod arguments; pub(crate) mod bind; -pub(super) use arguments::{Argument, CallArgumentTypes, CallArguments}; +pub(super) use arguments::{Argument, CallArguments}; pub(super) use bind::{Binding, Bindings, CallableBinding}; /// Wraps a [`Bindings`] for an unsuccessful call with information about why the call was diff --git a/crates/ty_python_semantic/src/types/call/arguments.rs b/crates/ty_python_semantic/src/types/call/arguments.rs index 77a4eb3976..89c3cf0112 100644 --- a/crates/ty_python_semantic/src/types/call/arguments.rs +++ b/crates/ty_python_semantic/src/types/call/arguments.rs @@ -1,5 +1,4 @@ use std::borrow::Cow; -use std::ops::{Deref, DerefMut}; use itertools::{Either, Itertools}; use ruff_python_ast as ast; @@ -10,60 +9,6 @@ use crate::types::tuple::{TupleSpec, TupleType}; use super::Type; -/// Arguments for a single call, in source order. -#[derive(Clone, Debug, Default)] -pub(crate) struct CallArguments<'a>(Vec>); - -impl<'a> CallArguments<'a> { - /// Create `CallArguments` from AST arguments - pub(crate) fn from_arguments(arguments: &'a ast::Arguments) -> Self { - arguments - .arguments_source_order() - .map(|arg_or_keyword| match arg_or_keyword { - ast::ArgOrKeyword::Arg(arg) => match arg { - ast::Expr::Starred(ast::ExprStarred { .. }) => Argument::Variadic, - _ => Argument::Positional, - }, - ast::ArgOrKeyword::Keyword(ast::Keyword { arg, .. }) => { - if let Some(arg) = arg { - Argument::Keyword(&arg.id) - } else { - Argument::Keywords - } - } - }) - .collect() - } - - /// Prepend an optional extra synthetic argument (for a `self` or `cls` parameter) to the front - /// of this argument list. (If `bound_self` is none, we return the argument list - /// unmodified.) - pub(crate) fn with_self(&self, bound_self: Option>) -> Cow { - if bound_self.is_some() { - let arguments = std::iter::once(Argument::Synthetic) - .chain(self.0.iter().copied()) - .collect(); - Cow::Owned(CallArguments(arguments)) - } else { - Cow::Borrowed(self) - } - } - - pub(crate) fn len(&self) -> usize { - self.0.len() - } - - pub(crate) fn iter(&self) -> impl Iterator> + '_ { - self.0.iter().copied() - } -} - -impl<'a> FromIterator> for CallArguments<'a> { - fn from_iter>>(iter: T) -> Self { - Self(iter.into_iter().collect()) - } -} - #[derive(Clone, Copy, Debug)] pub(crate) enum Argument<'a> { /// The synthetic `self` or `cls` argument, which doesn't appear explicitly at the call site. @@ -80,64 +25,87 @@ pub(crate) enum Argument<'a> { /// Arguments for a single call, in source order, along with inferred types for each argument. #[derive(Clone, Debug, Default)] -pub(crate) struct CallArgumentTypes<'a, 'db> { - arguments: CallArguments<'a>, - types: Vec>, +pub(crate) struct CallArguments<'a, 'db> { + arguments: Vec>, + types: Vec>>, } -impl<'a, 'db> CallArgumentTypes<'a, 'db> { - /// Create a [`CallArgumentTypes`] with no arguments. +impl<'a, 'db> CallArguments<'a, 'db> { + fn new(arguments: Vec>, types: Vec>>) -> Self { + debug_assert!(arguments.len() == types.len()); + Self { arguments, types } + } + + /// Create `CallArguments` from AST arguments + pub(crate) fn from_arguments(arguments: &'a ast::Arguments) -> Self { + arguments + .arguments_source_order() + .map(|arg_or_keyword| match arg_or_keyword { + ast::ArgOrKeyword::Arg(arg) => match arg { + ast::Expr::Starred(ast::ExprStarred { .. }) => Argument::Variadic, + _ => Argument::Positional, + }, + ast::ArgOrKeyword::Keyword(ast::Keyword { arg, .. }) => { + if let Some(arg) = arg { + Argument::Keyword(&arg.id) + } else { + Argument::Keywords + } + } + }) + .map(|argument| (argument, None)) + .collect() + } + + /// Create a [`CallArguments`] with no arguments. pub(crate) fn none() -> Self { Self::default() } - /// Create a [`CallArgumentTypes`] from an iterator over non-variadic positional argument - /// types. + /// Create a [`CallArguments`] from an iterator over non-variadic positional argument types. pub(crate) fn positional(positional_tys: impl IntoIterator>) -> Self { - let types: Vec<_> = positional_tys.into_iter().collect(); - let arguments = CallArguments(vec![Argument::Positional; types.len()]); + let types: Vec<_> = positional_tys.into_iter().map(Some).collect(); + let arguments = vec![Argument::Positional; types.len()]; Self { arguments, types } } - /// Create a new [`CallArgumentTypes`] to store the inferred types of the arguments in a - /// [`CallArguments`]. Uses the provided callback to infer each argument type. - pub(crate) fn new(arguments: CallArguments<'a>, mut f: F) -> Self - where - F: FnMut(usize, Argument<'a>) -> Type<'db>, - { - let types = arguments - .iter() - .enumerate() - .map(|(idx, argument)| f(idx, argument)) - .collect(); - Self { arguments, types } + pub(crate) fn len(&self) -> usize { + self.arguments.len() } - pub(crate) fn types(&self) -> &[Type<'db>] { + pub(crate) fn types(&self) -> &[Option>] { &self.types } + pub(crate) fn iter_types(&self) -> impl Iterator> { + self.types.iter().map(|ty| ty.unwrap_or_else(Type::unknown)) + } + /// Prepend an optional extra synthetic argument (for a `self` or `cls` parameter) to the front /// of this argument list. (If `bound_self` is none, we return the argument list /// unmodified.) pub(crate) fn with_self(&self, bound_self: Option>) -> Cow { - if let Some(bound_self) = bound_self { - let arguments = CallArguments( - std::iter::once(Argument::Synthetic) - .chain(self.arguments.0.iter().copied()) - .collect(), - ); + if bound_self.is_some() { + let arguments = std::iter::once(Argument::Synthetic) + .chain(self.arguments.iter().copied()) + .collect(); let types = std::iter::once(bound_self) .chain(self.types.iter().copied()) .collect(); - Cow::Owned(CallArgumentTypes { arguments, types }) + Cow::Owned(CallArguments { arguments, types }) } else { Cow::Borrowed(self) } } - pub(crate) fn iter(&self) -> impl Iterator, Type<'db>)> + '_ { - self.arguments.iter().zip(self.types.iter().copied()) + pub(crate) fn iter(&self) -> impl Iterator, Option>)> + '_ { + (self.arguments.iter().copied()).zip(self.types.iter().copied()) + } + + pub(crate) fn iter_mut( + &mut self, + ) -> impl Iterator, &mut Option>)> + '_ { + (self.arguments.iter().copied()).zip(self.types.iter_mut()) } /// Returns an iterator on performing [argument type expansion]. @@ -146,17 +114,20 @@ impl<'a, 'db> CallArgumentTypes<'a, 'db> { /// contains the same arguments, but with one or more of the argument types expanded. /// /// [argument type expansion]: https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion - pub(crate) fn expand(&self, db: &'db dyn Db) -> impl Iterator>>> + '_ { + pub(crate) fn expand( + &self, + db: &'db dyn Db, + ) -> impl Iterator>> + '_ { /// Represents the state of the expansion process. /// /// This is useful to avoid cloning the initial types vector if none of the types can be /// expanded. - enum State<'a, 'db> { - Initial(&'a Vec>), - Expanded(Vec>>), + enum State<'a, 'b, 'db> { + Initial(&'b Vec>>), + Expanded(Vec>), } - impl<'db> State<'_, 'db> { + impl<'db> State<'_, '_, 'db> { fn len(&self) -> usize { match self { State::Initial(_) => 1, @@ -164,10 +135,12 @@ impl<'a, 'db> CallArgumentTypes<'a, 'db> { } } - fn iter(&self) -> impl Iterator>> + '_ { + fn iter(&self) -> impl Iterator>]> + '_ { match self { - State::Initial(types) => std::slice::from_ref(*types).iter(), - State::Expanded(expanded) => expanded.iter(), + State::Initial(types) => Either::Left(std::iter::once(types.as_slice())), + State::Expanded(expanded) => { + Either::Right(expanded.iter().map(CallArguments::types)) + } } } } @@ -178,26 +151,31 @@ impl<'a, 'db> CallArgumentTypes<'a, 'db> { // Find the next type that can be expanded. let expanded_types = loop { let arg_type = self.types.get(index)?; - if let Some(expanded_types) = expand_type(db, *arg_type) { - break expanded_types; + if let Some(arg_type) = arg_type { + if let Some(expanded_types) = expand_type(db, *arg_type) { + break expanded_types; + } } index += 1; }; - let mut expanded_arg_types = Vec::with_capacity(expanded_types.len() * previous.len()); + let mut expanded_arguments = Vec::with_capacity(expanded_types.len() * previous.len()); for pre_expanded_types in previous.iter() { for subtype in &expanded_types { - let mut new_expanded_types = pre_expanded_types.clone(); - new_expanded_types[index] = *subtype; - expanded_arg_types.push(new_expanded_types); + let mut new_expanded_types = pre_expanded_types.to_vec(); + new_expanded_types[index] = Some(*subtype); + expanded_arguments.push(CallArguments::new( + self.arguments.clone(), + new_expanded_types, + )); } } // Increment the index to move to the next argument type for the next iteration. index += 1; - Some(State::Expanded(expanded_arg_types)) + Some(State::Expanded(expanded_arguments)) }) .skip(1) // Skip the initial state, which has no expanded types. .map(|state| match state { @@ -207,16 +185,13 @@ impl<'a, 'db> CallArgumentTypes<'a, 'db> { } } -impl<'a> Deref for CallArgumentTypes<'a, '_> { - type Target = CallArguments<'a>; - fn deref(&self) -> &CallArguments<'a> { - &self.arguments - } -} - -impl<'a> DerefMut for CallArgumentTypes<'a, '_> { - fn deref_mut(&mut self) -> &mut CallArguments<'a> { - &mut self.arguments +impl<'a, 'db> FromIterator<(Argument<'a>, Option>)> for CallArguments<'a, 'db> { + fn from_iter(iter: T) -> Self + where + T: IntoIterator, Option>)>, + { + let (arguments, types) = iter.into_iter().unzip(); + Self { arguments, types } } } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index e8cfeb8311..b0157c2e81 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -10,10 +10,7 @@ use itertools::Itertools; use ruff_db::parsed::parsed_module; use smallvec::{SmallVec, smallvec}; -use super::{ - Argument, CallArgumentTypes, CallArguments, CallError, CallErrorKind, InferContext, Signature, - Type, -}; +use super::{Argument, CallArguments, CallError, CallErrorKind, InferContext, Signature, Type}; use crate::db::Db; use crate::dunder_all::dunder_all_names; use crate::place::{Boundness, Place}; @@ -95,12 +92,11 @@ impl<'db> Bindings<'db> { /// /// The returned bindings tell you which parameter (in each signature) each argument was /// matched against. You can then perform type inference on each argument with extra context - /// about the expected parameter types. (You do this by creating a [`CallArgumentTypes`] object - /// from the `arguments` that you match against.) + /// about the expected parameter types. /// /// Once you have argument types available, you can call [`check_types`][Self::check_types] to /// verify that each argument type is assignable to the corresponding parameter type. - pub(crate) fn match_parameters(mut self, arguments: &CallArguments<'_>) -> Self { + pub(crate) fn match_parameters(mut self, arguments: &CallArguments<'_, 'db>) -> Self { let mut argument_forms = vec![None; arguments.len()]; let mut conflicting_forms = vec![false; arguments.len()]; for binding in &mut self.elements { @@ -123,7 +119,7 @@ impl<'db> Bindings<'db> { pub(crate) fn check_types( mut self, db: &'db dyn Db, - argument_types: &CallArgumentTypes<'_, 'db>, + argument_types: &CallArguments<'_, 'db>, ) -> Result> { for element in &mut self.elements { element.check_types(db, argument_types); @@ -410,7 +406,7 @@ impl<'db> Bindings<'db> { [Some(Type::PropertyInstance(property)), Some(instance), ..] => { if let Some(getter) = property.getter(db) { if let Ok(return_ty) = getter - .try_call(db, &CallArgumentTypes::positional([*instance])) + .try_call(db, &CallArguments::positional([*instance])) .map(|binding| binding.return_type(db)) { overload.set_return_type(return_ty); @@ -439,7 +435,7 @@ impl<'db> Bindings<'db> { [Some(instance), ..] => { if let Some(getter) = property.getter(db) { if let Ok(return_ty) = getter - .try_call(db, &CallArgumentTypes::positional([*instance])) + .try_call(db, &CallArguments::positional([*instance])) .map(|binding| binding.return_type(db)) { overload.set_return_type(return_ty); @@ -469,10 +465,9 @@ impl<'db> Bindings<'db> { ] = overload.parameter_types() { if let Some(setter) = property.setter(db) { - if let Err(_call_error) = setter.try_call( - db, - &CallArgumentTypes::positional([*instance, *value]), - ) { + if let Err(_call_error) = setter + .try_call(db, &CallArguments::positional([*instance, *value])) + { overload.errors.push(BindingError::InternalCallError( "calling the setter failed", )); @@ -488,10 +483,9 @@ impl<'db> Bindings<'db> { Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet(property)) => { if let [Some(instance), Some(value), ..] = overload.parameter_types() { if let Some(setter) = property.setter(db) { - if let Err(_call_error) = setter.try_call( - db, - &CallArgumentTypes::positional([*instance, *value]), - ) { + if let Err(_call_error) = setter + .try_call(db, &CallArguments::positional([*instance, *value])) + { overload.errors.push(BindingError::InternalCallError( "calling the setter failed", )); @@ -1161,7 +1155,7 @@ impl<'db> CallableBinding<'db> { fn match_parameters( &mut self, - arguments: &CallArguments<'_>, + arguments: &CallArguments<'_, 'db>, argument_forms: &mut [Option], conflicting_forms: &mut [bool], ) { @@ -1174,7 +1168,7 @@ impl<'db> CallableBinding<'db> { } } - fn check_types(&mut self, db: &'db dyn Db, argument_types: &CallArgumentTypes<'_, 'db>) { + fn check_types(&mut self, db: &'db dyn Db, argument_types: &CallArguments<'_, 'db>) { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. let argument_types = argument_types.with_self(self.bound_type); @@ -1186,7 +1180,7 @@ impl<'db> CallableBinding<'db> { // still perform type checking for non-overloaded function to provide better user // experience. if let [overload] = self.overloads.as_mut_slice() { - overload.check_types(db, argument_types.as_ref(), argument_types.types()); + overload.check_types(db, argument_types.as_ref()); } return; } @@ -1194,11 +1188,7 @@ impl<'db> CallableBinding<'db> { // If only one candidate overload remains, it is the winning match. Evaluate it as // a regular (non-overloaded) call. self.matching_overload_index = Some(index); - self.overloads[index].check_types( - db, - argument_types.as_ref(), - argument_types.types(), - ); + self.overloads[index].check_types(db, argument_types.as_ref()); return; } MatchingOverloadIndex::Multiple(indexes) => { @@ -1216,7 +1206,7 @@ impl<'db> CallableBinding<'db> { // Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine // whether it is compatible with the supplied argument list. for (_, overload) in self.matching_overloads_mut() { - overload.check_types(db, argument_types.as_ref(), argument_types.types()); + overload.check_types(db, argument_types.as_ref()); } match self.matching_overload_index() { @@ -1232,7 +1222,7 @@ impl<'db> CallableBinding<'db> { // TODO: Step 4 // Step 5 - self.filter_overloads_using_any_or_unknown(db, argument_types.types(), &indexes); + self.filter_overloads_using_any_or_unknown(db, argument_types.as_ref(), &indexes); // We're returning here because this shouldn't lead to argument type expansion. return; @@ -1268,7 +1258,7 @@ impl<'db> CallableBinding<'db> { let pre_evaluation_snapshot = snapshotter.take(self); for (_, overload) in self.matching_overloads_mut() { - overload.check_types(db, argument_types.as_ref(), expanded_argument_types); + overload.check_types(db, expanded_argument_types); } let return_type = match self.matching_overload_index() { @@ -1358,7 +1348,7 @@ impl<'db> CallableBinding<'db> { fn filter_overloads_using_any_or_unknown( &mut self, db: &'db dyn Db, - argument_types: &[Type<'db>], + arguments: &CallArguments<'_, 'db>, matching_overload_indexes: &[usize], ) { // These are the parameter indexes that matches the arguments that participate in the @@ -1372,7 +1362,7 @@ impl<'db> CallableBinding<'db> { // participating parameter indexes. let mut top_materialized_argument_types = vec![]; - for (argument_index, argument_type) in argument_types.iter().enumerate() { + for (argument_index, argument_type) in arguments.iter_types().enumerate() { let mut first_parameter_type: Option> = None; let mut participating_parameter_index = None; @@ -1415,8 +1405,8 @@ impl<'db> CallableBinding<'db> { self.overloads[*current_index].mark_as_unmatched_overload(); continue; } - let mut parameter_types = Vec::with_capacity(argument_types.len()); - for argument_index in 0..argument_types.len() { + let mut parameter_types = Vec::with_capacity(arguments.len()); + for argument_index in 0..arguments.len() { // The parameter types at the current argument index. let mut current_parameter_types = vec![]; for overload_index in &matching_overload_indexes[..=upto] { @@ -1904,8 +1894,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { struct ArgumentTypeChecker<'a, 'db> { db: &'db dyn Db, signature: &'a Signature<'db>, - arguments: &'a CallArguments<'a>, - argument_types: &'a [Type<'db>], + arguments: &'a CallArguments<'a, 'db>, argument_parameters: &'a [Option], parameter_tys: &'a mut [Option>], errors: &'a mut Vec>, @@ -1918,8 +1907,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { fn new( db: &'db dyn Db, signature: &'a Signature<'db>, - arguments: &'a CallArguments<'a>, - argument_types: &'a [Type<'db>], + arguments: &'a CallArguments<'a, 'db>, argument_parameters: &'a [Option], parameter_tys: &'a mut [Option>], errors: &'a mut Vec>, @@ -1928,7 +1916,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { db, signature, arguments, - argument_types, argument_parameters, parameter_tys, errors, @@ -1940,9 +1927,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { fn enumerate_argument_types( &self, ) -> impl Iterator, Argument<'a>, Type<'db>)> + 'a { - let mut iter = (self.arguments.iter()) - .zip(self.argument_types.iter().copied()) - .enumerate(); + let mut iter = self.arguments.iter().enumerate(); let mut num_synthetic_args = 0; std::iter::from_fn(move || { let (argument_index, (argument, argument_type)) = iter.next()?; @@ -1961,7 +1946,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { argument_index, adjusted_argument_index, argument, - argument_type, + argument_type.unwrap_or_else(Type::unknown), )) }) } @@ -2127,7 +2112,7 @@ impl<'db> Binding<'db> { pub(crate) fn match_parameters( &mut self, - arguments: &CallArguments<'_>, + arguments: &CallArguments<'_, 'db>, argument_forms: &mut [Option], conflicting_forms: &mut [bool], ) { @@ -2139,7 +2124,7 @@ impl<'db> Binding<'db> { conflicting_forms, &mut self.errors, ); - for (argument_index, argument) in arguments.iter().enumerate() { + for (argument_index, (argument, _)) in arguments.iter().enumerate() { match argument { Argument::Positional | Argument::Synthetic => { let _ = matcher.match_positional(argument_index, argument); @@ -2158,17 +2143,11 @@ impl<'db> Binding<'db> { self.argument_parameters = matcher.finish(); } - fn check_types( - &mut self, - db: &'db dyn Db, - arguments: &CallArguments<'_>, - argument_types: &[Type<'db>], - ) { + fn check_types(&mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>) { let mut checker = ArgumentTypeChecker::new( db, &self.signature, arguments, - argument_types, &self.argument_parameters, &mut self.parameter_tys, &mut self.errors, @@ -2210,7 +2189,7 @@ impl<'db> Binding<'db> { pub(crate) fn arguments_for_parameter<'a>( &'a self, - argument_types: &'a CallArgumentTypes<'a, 'db>, + argument_types: &'a CallArguments<'a, 'db>, parameter_index: usize, ) -> impl Iterator, Type<'db>)> + 'a { argument_types @@ -2219,7 +2198,9 @@ impl<'db> Binding<'db> { .filter(move |(_, argument_parameter)| { argument_parameter.is_some_and(|ap| ap == parameter_index) }) - .map(|(arg_and_type, _)| arg_and_type) + .map(|((argument, argument_type), _)| { + (argument, argument_type.unwrap_or_else(Type::unknown)) + }) } /// Mark this overload binding as an unmatched overload. diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index b2c2c0ff32..fa429c4dce 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -38,7 +38,7 @@ use crate::{ place_table, semantic_index, use_def_map, }, types::{ - CallArgumentTypes, CallError, CallErrorKind, MetaclassCandidate, UnionBuilder, UnionType, + CallArguments, CallError, CallErrorKind, MetaclassCandidate, UnionBuilder, UnionType, definition_expression_type, }, }; @@ -1210,7 +1210,7 @@ impl<'db> ClassLiteral<'db> { .to_specialized_instance(db, [KnownClass::Str.to_instance(db), Type::any()]); // TODO: Other keyword arguments? - let arguments = CallArgumentTypes::positional([name, bases, namespace]); + let arguments = CallArguments::positional([name, bases, namespace]); let return_ty_result = match metaclass.try_call(db, &arguments) { Ok(bindings) => Ok(bindings.return_type(db)), @@ -3312,7 +3312,7 @@ impl KnownClass { context: &InferContext<'db, '_>, index: &SemanticIndex<'db>, overload_binding: &Binding<'db>, - call_argument_types: &CallArgumentTypes<'_, 'db>, + call_argument_types: &CallArguments<'_, 'db>, call_expression: &ast::ExprCall, ) -> Option> { let db = context.db(); diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index e3cf81f75f..8ea50d48d0 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -2,7 +2,7 @@ use super::call::CallErrorKind; use super::context::InferContext; use super::mro::DuplicateBaseError; use super::{ - CallArgumentTypes, CallDunderError, ClassBase, ClassLiteral, KnownClass, + CallArguments, CallDunderError, ClassBase, ClassLiteral, KnownClass, add_inferred_python_version_hint_to_diagnostic, }; use crate::lint::{Level, LintRegistryBuilder, LintStatus}; @@ -2325,7 +2325,7 @@ pub(crate) fn report_invalid_or_unsupported_base( match base_type.try_call_dunder( db, "__mro_entries__", - CallArgumentTypes::positional([tuple_of_types]), + CallArguments::positional([tuple_of_types]), ) { Ok(ret) => { if ret.return_type(db).is_assignable_to(db, tuple_of_types) { diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index d0a1ee33ee..94066916e9 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -85,7 +85,7 @@ use crate::semantic_index::place::{ use crate::semantic_index::{ ApplicableConstraints, EagerSnapshotResult, SemanticIndex, place_table, semantic_index, }; -use crate::types::call::{Binding, Bindings, CallArgumentTypes, CallArguments, CallError}; +use crate::types::call::{Binding, Bindings, CallArguments, CallError}; use crate::types::class::{CodeGeneratorKind, MetaclassErrorKind, SliceLiteral}; use crate::types::diagnostic::{ self, CALL_NON_CALLABLE, CONFLICTING_DECLARATIONS, CONFLICTING_METACLASS, @@ -1965,9 +1965,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.infer_type_parameters(type_params); if let Some(arguments) = class.arguments.as_deref() { - let call_arguments = CallArguments::from_arguments(arguments); + let mut call_arguments = CallArguments::from_arguments(arguments); let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()]; - self.infer_argument_types(arguments, call_arguments, &argument_forms); + self.infer_argument_types(arguments, &mut call_arguments, &argument_forms); } } @@ -2374,7 +2374,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() { inferred_ty = match decorator_ty - .try_call(self.db(), &CallArgumentTypes::positional([inferred_ty])) + .try_call(self.db(), &CallArguments::positional([inferred_ty])) .map(|bindings| bindings.return_type(self.db())) { Ok(return_ty) => return_ty, @@ -3460,10 +3460,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let setattr_dunder_call_result = object_ty.try_call_dunder_with_policy( db, "__setattr__", - &mut CallArgumentTypes::positional([ - Type::string_literal(db, attribute), - value_ty, - ]), + &mut CallArguments::positional([Type::string_literal(db, attribute), value_ty]), MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, ); @@ -3548,7 +3545,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let successful_call = meta_dunder_set .try_call( db, - &CallArgumentTypes::positional([ + &CallArguments::positional([ meta_attr_ty, object_ty, value_ty, @@ -3674,11 +3671,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let successful_call = meta_dunder_set .try_call( db, - &CallArgumentTypes::positional([ - meta_attr_ty, - object_ty, - value_ty, - ]), + &CallArguments::positional([meta_attr_ty, object_ty, value_ty]), ) .is_ok(); @@ -4099,7 +4092,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let call = target_type.try_call_dunder( db, op.in_place_dunder(), - CallArgumentTypes::positional([value_type]), + CallArguments::positional([value_type]), ); match call { @@ -4740,28 +4733,31 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { fn infer_argument_types<'a>( &mut self, ast_arguments: &ast::Arguments, - arguments: CallArguments<'a>, + arguments: &mut CallArguments<'a, 'db>, argument_forms: &[Option], - ) -> CallArgumentTypes<'a, 'db> { - let mut ast_arguments = ast_arguments.arguments_source_order(); - CallArgumentTypes::new(arguments, |index, _| { - let arg_or_keyword = ast_arguments - .next() - .expect("argument lists should have consistent lengths"); - match arg_or_keyword { + ) { + debug_assert!( + ast_arguments.len() == arguments.len() && arguments.len() == argument_forms.len() + ); + let iter = (arguments.iter_mut()) + .zip(argument_forms.iter().copied()) + .zip(ast_arguments.arguments_source_order()); + for (((_, argument_type), form), arg_or_keyword) in iter { + let ty = match arg_or_keyword { ast::ArgOrKeyword::Arg(arg) => match arg { ast::Expr::Starred(ast::ExprStarred { value, .. }) => { - let ty = self.infer_argument_type(value, argument_forms[index]); + let ty = self.infer_argument_type(value, form); self.store_expression_type(arg, ty); ty } - _ => self.infer_argument_type(arg, argument_forms[index]), + _ => self.infer_argument_type(arg, form), }, ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => { - self.infer_argument_type(value, argument_forms[index]) + self.infer_argument_type(value, form) } - } - }) + }; + *argument_type = Some(ty); + } } fn infer_argument_type( @@ -5450,7 +5446,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // We don't call `Type::try_call`, because we want to perform type inference on the // arguments after matching them to parameters, but before checking that the argument types // are assignable to any parameter annotations. - let call_arguments = CallArguments::from_arguments(arguments); + let mut call_arguments = CallArguments::from_arguments(arguments); let callable_type = self.infer_maybe_standalone_expression(func); @@ -5523,11 +5519,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .is_none_or(|enum_class| !class.is_subclass_of(self.db(), enum_class)) { let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()]; - let call_argument_types = - self.infer_argument_types(arguments, call_arguments, &argument_forms); + self.infer_argument_types(arguments, &mut call_arguments, &argument_forms); return callable_type - .try_call_constructor(self.db(), call_argument_types) + .try_call_constructor(self.db(), call_arguments) .unwrap_or_else(|err| { err.report_diagnostic(&self.context, callable_type, call_expression.into()); err.return_type() @@ -5538,10 +5533,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let bindings = callable_type .bindings(self.db()) .match_parameters(&call_arguments); - let call_argument_types = - self.infer_argument_types(arguments, call_arguments, &bindings.argument_forms); + self.infer_argument_types(arguments, &mut call_arguments, &bindings.argument_forms); - let mut bindings = match bindings.check_types(self.db(), &call_argument_types) { + let mut bindings = match bindings.check_types(self.db(), &call_arguments) { Ok(bindings) => bindings, Err(CallError(_, bindings)) => { bindings.report_diagnostics(&self.context, call_expression.into()); @@ -5574,7 +5568,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &self.context, self.index, overload, - &call_argument_types, + &call_arguments, call_expression, ); if let Some(overridden_return) = overridden_return { @@ -6399,7 +6393,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match operand_type.try_call_dunder( self.db(), unary_dunder_method, - CallArgumentTypes::none(), + CallArguments::none(), ) { Ok(outcome) => outcome.return_type(self.db()), Err(e) => { @@ -6773,7 +6767,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .try_call_dunder( self.db(), reflected_dunder, - CallArgumentTypes::positional([left_ty]), + CallArguments::positional([left_ty]), ) .map(|outcome| outcome.return_type(self.db())) .or_else(|_| { @@ -6781,7 +6775,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .try_call_dunder( self.db(), op.dunder(), - CallArgumentTypes::positional([right_ty]), + CallArguments::positional([right_ty]), ) .map(|outcome| outcome.return_type(self.db())) }) @@ -6793,7 +6787,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .try_call_dunder( self.db(), op.dunder(), - CallArgumentTypes::positional([right_ty]), + CallArguments::positional([right_ty]), ) .map(|outcome| outcome.return_type(self.db())) .ok(); @@ -6806,7 +6800,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .try_call_dunder( self.db(), op.reflected_dunder(), - CallArgumentTypes::positional([left_ty]), + CallArguments::positional([left_ty]), ) .map(|outcome| outcome.return_type(self.db())) .ok() @@ -7538,7 +7532,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // The following resource has details about the rich comparison algorithm: // https://snarky.ca/unravelling-rich-comparison-operators/ let call_dunder = |op: RichCompareOperator, left: Type<'db>, right: Type<'db>| { - left.try_call_dunder(db, op.dunder(), CallArgumentTypes::positional([right])) + left.try_call_dunder(db, op.dunder(), CallArguments::positional([right])) .map(|outcome| outcome.return_type(db)) .ok() }; @@ -7584,7 +7578,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Place::Type(contains_dunder, Boundness::Bound) => { // If `__contains__` is available, it is used directly for the membership test. contains_dunder - .try_call(db, &CallArgumentTypes::positional([right, left])) + .try_call(db, &CallArguments::positional([right, left])) .map(|bindings| bindings.return_type(db)) .ok() } @@ -7807,16 +7801,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let slice_node = subscript.slice.as_ref(); let call_argument_types = match slice_node { ast::Expr::Tuple(tuple) => { - let arguments = CallArgumentTypes::positional( + let arguments = CallArguments::positional( tuple.elts.iter().map(|elt| self.infer_type_expression(elt)), ); self.store_expression_type( slice_node, - TupleType::from_elements(self.db(), arguments.iter().map(|(_, ty)| ty)), + TupleType::from_elements(self.db(), arguments.iter_types()), ); arguments } - _ => CallArgumentTypes::positional([self.infer_type_expression(slice_node)]), + _ => CallArguments::positional([self.infer_type_expression(slice_node)]), }; let binding = Binding::single(value_ty, generic_context.signature(self.db())); let bindings = match Bindings::from(binding) @@ -8066,7 +8060,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match value_ty.try_call_dunder( self.db(), "__getitem__", - CallArgumentTypes::positional([slice_ty]), + CallArguments::positional([slice_ty]), ) { Ok(outcome) => return outcome.return_type(self.db()), Err(err @ CallDunderError::PossiblyUnbound { .. }) => { @@ -8132,7 +8126,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match ty.try_call( self.db(), - &CallArgumentTypes::positional([value_ty, slice_ty]), + &CallArguments::positional([value_ty, slice_ty]), ) { Ok(bindings) => return bindings.return_type(self.db()), Err(CallError(_, bindings)) => {