[ty] Refactor argument matching / type checking in call binding (#18997)

This PR extracts a lot of the complex logic in the `match_parameters`
and `check_types` methods of our call binding machinery into separate
helper types. This is setup for #18996, which will update this logic to
handle variadic arguments. To do so, it is helpful to have the
per-argument logic extracted into a method that we can call repeatedly
for each _element_ of a variadic argument.

This should be a pure refactoring, with no behavioral changes.
This commit is contained in:
Douglas Creager 2025-06-27 17:01:52 -04:00 committed by GitHub
parent c60e590b4c
commit caf3c916e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -26,7 +26,7 @@ use crate::types::function::{
DataclassTransformerParams, FunctionDecorators, FunctionType, KnownFunction, OverloadLiteral,
};
use crate::types::generics::{Specialization, SpecializationBuilder, SpecializationError};
use crate::types::signatures::{Parameter, ParameterForm};
use crate::types::signatures::{Parameter, ParameterForm, Parameters};
use crate::types::tuple::TupleType;
use crate::types::{
BoundMethodType, ClassLiteral, DataclassParams, KnownClass, KnownInstanceType,
@ -1754,6 +1754,334 @@ enum MatchingOverloadIndex {
Multiple(Vec<usize>),
}
struct ArgumentMatcher<'a, 'db> {
parameters: &'a Parameters<'db>,
argument_forms: &'a mut [Option<ParameterForm>],
conflicting_forms: &'a mut [bool],
errors: &'a mut Vec<BindingError<'db>>,
/// The parameter that each argument is matched with.
argument_parameters: Vec<Option<usize>>,
/// Whether each parameter has been matched with an argument.
parameter_matched: Vec<bool>,
next_positional: usize,
first_excess_positional: Option<usize>,
num_synthetic_args: usize,
}
impl<'a, 'db> ArgumentMatcher<'a, 'db> {
fn new(
arguments: &CallArguments,
parameters: &'a Parameters<'db>,
argument_forms: &'a mut [Option<ParameterForm>],
conflicting_forms: &'a mut [bool],
errors: &'a mut Vec<BindingError<'db>>,
) -> Self {
Self {
parameters,
argument_forms,
conflicting_forms,
errors,
argument_parameters: vec![None; arguments.len()],
parameter_matched: vec![false; parameters.len()],
next_positional: 0,
first_excess_positional: None,
num_synthetic_args: 0,
}
}
fn get_argument_index(&self, argument_index: usize) -> Option<usize> {
if argument_index >= self.num_synthetic_args {
// Adjust the argument index to skip synthetic args, which don't appear at the call
// site and thus won't be in the Call node arguments list.
Some(argument_index - self.num_synthetic_args)
} else {
// we are erroring on a synthetic argument, we'll just emit the diagnostic on the
// entire Call node, since there's no argument node for this argument at the call site
None
}
}
fn assign_argument(
&mut self,
argument_index: usize,
argument: Argument<'a>,
parameter_index: usize,
parameter: &Parameter<'db>,
positional: bool,
) {
if !matches!(argument, Argument::Synthetic) {
if let Some(existing) = self.argument_forms[argument_index - self.num_synthetic_args]
.replace(parameter.form)
{
if existing != parameter.form {
self.conflicting_forms[argument_index - self.num_synthetic_args] = true;
}
}
}
if self.parameter_matched[parameter_index] {
if !parameter.is_variadic() && !parameter.is_keyword_variadic() {
self.errors.push(BindingError::ParameterAlreadyAssigned {
argument_index: self.get_argument_index(argument_index),
parameter: ParameterContext::new(parameter, parameter_index, positional),
});
}
}
self.argument_parameters[argument_index] = Some(parameter_index);
self.parameter_matched[parameter_index] = true;
}
fn match_positional(
&mut self,
argument_index: usize,
argument: Argument<'a>,
) -> Result<(), ()> {
if matches!(argument, Argument::Synthetic) {
self.num_synthetic_args += 1;
}
let Some((parameter_index, parameter)) = self
.parameters
.get_positional(self.next_positional)
.map(|param| (self.next_positional, param))
.or_else(|| self.parameters.variadic())
else {
self.first_excess_positional.get_or_insert(argument_index);
self.next_positional += 1;
return Err(());
};
self.next_positional += 1;
self.assign_argument(
argument_index,
argument,
parameter_index,
parameter,
!parameter.is_variadic(),
);
Ok(())
}
fn match_keyword(
&mut self,
argument_index: usize,
argument: Argument<'a>,
name: &str,
) -> Result<(), ()> {
let Some((parameter_index, parameter)) = self
.parameters
.keyword_by_name(name)
.or_else(|| self.parameters.keyword_variadic())
else {
self.errors.push(BindingError::UnknownArgument {
argument_name: ast::name::Name::new(name),
argument_index: self.get_argument_index(argument_index),
});
return Err(());
};
self.assign_argument(argument_index, argument, parameter_index, parameter, false);
Ok(())
}
fn finish(self) -> Box<[Option<usize>]> {
if let Some(first_excess_argument_index) = self.first_excess_positional {
self.errors.push(BindingError::TooManyPositionalArguments {
first_excess_argument_index: self.get_argument_index(first_excess_argument_index),
expected_positional_count: self.parameters.positional().count(),
provided_positional_count: self.next_positional,
});
}
let mut missing = vec![];
for (index, matched) in self.parameter_matched.iter().copied().enumerate() {
if !matched {
let param = &self.parameters[index];
if param.is_variadic()
|| param.is_keyword_variadic()
|| param.default_type().is_some()
{
// variadic/keywords and defaulted arguments are not required
continue;
}
missing.push(ParameterContext::new(param, index, false));
}
}
if !missing.is_empty() {
self.errors.push(BindingError::MissingArguments {
parameters: ParameterContexts(missing),
});
}
self.argument_parameters.into_boxed_slice()
}
}
struct ArgumentTypeChecker<'a, 'db> {
db: &'db dyn Db,
signature: &'a Signature<'db>,
arguments: &'a CallArguments<'a>,
argument_types: &'a [Type<'db>],
argument_parameters: &'a [Option<usize>],
parameter_tys: &'a mut [Option<Type<'db>>],
errors: &'a mut Vec<BindingError<'db>>,
specialization: Option<Specialization<'db>>,
inherited_specialization: Option<Specialization<'db>>,
}
impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
fn new(
db: &'db dyn Db,
signature: &'a Signature<'db>,
arguments: &'a CallArguments<'a>,
argument_types: &'a [Type<'db>],
argument_parameters: &'a [Option<usize>],
parameter_tys: &'a mut [Option<Type<'db>>],
errors: &'a mut Vec<BindingError<'db>>,
) -> Self {
Self {
db,
signature,
arguments,
argument_types,
argument_parameters,
parameter_tys,
errors,
specialization: None,
inherited_specialization: None,
}
}
fn enumerate_argument_types(
&self,
) -> impl Iterator<Item = (usize, Option<usize>, Argument<'a>, Type<'db>)> + 'a {
let mut iter = (self.arguments.iter())
.zip(self.argument_types.iter().copied())
.enumerate();
let mut num_synthetic_args = 0;
std::iter::from_fn(move || {
let (argument_index, (argument, argument_type)) = iter.next()?;
let adjusted_argument_index = if matches!(argument, Argument::Synthetic) {
// If we are erroring on a synthetic argument, we'll just emit the
// diagnostic on the entire Call node, since there's no argument node for
// this argument at the call site
num_synthetic_args += 1;
None
} else {
// Adjust the argument index to skip synthetic args, which don't appear at
// the call site and thus won't be in the Call node arguments list.
Some(argument_index - num_synthetic_args)
};
Some((
argument_index,
adjusted_argument_index,
argument,
argument_type,
))
})
}
fn infer_specialization(&mut self) {
if self.signature.generic_context.is_none()
&& self.signature.inherited_generic_context.is_none()
{
return;
}
let parameters = self.signature.parameters();
let mut builder = SpecializationBuilder::new(self.db);
for (argument_index, adjusted_argument_index, _, argument_type) in
self.enumerate_argument_types()
{
let Some(parameter_index) = self.argument_parameters[argument_index] else {
// There was an error with argument when matching parameters, so don't bother
// type-checking it.
continue;
};
let parameter = &parameters[parameter_index];
let Some(expected_type) = parameter.annotated_type() else {
continue;
};
if let Err(error) = builder.infer(expected_type, argument_type) {
self.errors.push(BindingError::SpecializationError {
error,
argument_index: adjusted_argument_index,
});
}
}
self.specialization = self.signature.generic_context.map(|gc| builder.build(gc));
self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| {
// The inherited generic context is used when inferring the specialization of a generic
// class from a constructor call. In this case (only), we promote any typevars that are
// inferred as a literal to the corresponding instance type.
builder
.build(gc)
.apply_type_mapping(self.db, &TypeMapping::PromoteLiterals)
});
}
fn check_argument_type(
&mut self,
argument_index: usize,
adjusted_argument_index: Option<usize>,
argument: Argument<'a>,
mut argument_type: Type<'db>,
) {
let Some(parameter_index) = self.argument_parameters[argument_index] else {
// There was an error with argument when matching parameters, so don't bother
// type-checking it.
return;
};
let parameters = self.signature.parameters();
let parameter = &parameters[parameter_index];
if let Some(mut expected_ty) = parameter.annotated_type() {
if let Some(specialization) = self.specialization {
argument_type = argument_type.apply_specialization(self.db, specialization);
expected_ty = expected_ty.apply_specialization(self.db, specialization);
}
if let Some(inherited_specialization) = self.inherited_specialization {
argument_type =
argument_type.apply_specialization(self.db, inherited_specialization);
expected_ty = expected_ty.apply_specialization(self.db, inherited_specialization);
}
if !argument_type.is_assignable_to(self.db, expected_ty) {
let positional = matches!(argument, Argument::Positional | Argument::Synthetic)
&& !parameter.is_variadic();
self.errors.push(BindingError::InvalidArgumentType {
parameter: ParameterContext::new(parameter, parameter_index, positional),
argument_index: adjusted_argument_index,
expected_ty,
provided_ty: argument_type,
});
}
}
// We still update the actual type of the parameter in this binding to match the
// argument, even if the argument type is not assignable to the expected parameter
// type.
if let Some(existing) = self.parameter_tys[parameter_index].replace(argument_type) {
// We already verified in `match_parameters` that we only match multiple arguments
// with variadic parameters.
let union = UnionType::from_elements(self.db, [existing, argument_type]);
self.parameter_tys[parameter_index] = Some(union);
}
}
fn check_argument_types(&mut self) {
for (argument_index, adjusted_argument_index, argument, argument_type) in
self.enumerate_argument_types()
{
self.check_argument_type(
argument_index,
adjusted_argument_index,
argument,
argument_type,
);
}
}
fn finish(self) -> (Option<Specialization<'db>>, Option<Specialization<'db>>) {
(self.specialization, self.inherited_specialization)
}
}
/// Binding information for one of the overloads of a callable.
#[derive(Debug)]
pub(crate) struct Binding<'db> {
@ -1817,115 +2145,30 @@ impl<'db> Binding<'db> {
conflicting_forms: &mut [bool],
) {
let parameters = self.signature.parameters();
// The parameter that each argument is matched with.
let mut argument_parameters = vec![None; arguments.len()];
// Whether each parameter has been matched with an argument.
let mut parameter_matched = vec![false; parameters.len()];
let mut next_positional = 0;
let mut first_excess_positional = None;
let mut num_synthetic_args = 0;
let get_argument_index = |argument_index: usize, num_synthetic_args: usize| {
if argument_index >= num_synthetic_args {
// Adjust the argument index to skip synthetic args, which don't appear at the call
// site and thus won't be in the Call node arguments list.
Some(argument_index - num_synthetic_args)
} else {
// we are erroring on a synthetic argument, we'll just emit the diagnostic on the
// entire Call node, since there's no argument node for this argument at the call site
None
}
};
let mut matcher = ArgumentMatcher::new(
arguments,
parameters,
argument_forms,
conflicting_forms,
&mut self.errors,
);
for (argument_index, argument) in arguments.iter().enumerate() {
let (index, parameter, positional) = match argument {
match argument {
Argument::Positional | Argument::Synthetic => {
if matches!(argument, Argument::Synthetic) {
num_synthetic_args += 1;
}
let Some((index, parameter)) = parameters
.get_positional(next_positional)
.map(|param| (next_positional, param))
.or_else(|| parameters.variadic())
else {
first_excess_positional.get_or_insert(argument_index);
next_positional += 1;
continue;
};
next_positional += 1;
(index, parameter, !parameter.is_variadic())
let _ = matcher.match_positional(argument_index, argument);
}
Argument::Keyword(name) => {
let Some((index, parameter)) = parameters
.keyword_by_name(name)
.or_else(|| parameters.keyword_variadic())
else {
self.errors.push(BindingError::UnknownArgument {
argument_name: ast::name::Name::new(name),
argument_index: get_argument_index(argument_index, num_synthetic_args),
});
continue;
};
(index, parameter, false)
let _ = matcher.match_keyword(argument_index, argument, name);
}
Argument::Variadic | Argument::Keywords => {
// TODO
continue;
}
};
if !matches!(argument, Argument::Synthetic) {
if let Some(existing) =
argument_forms[argument_index - num_synthetic_args].replace(parameter.form)
{
if existing != parameter.form {
conflicting_forms[argument_index - num_synthetic_args] = true;
}
}
}
if parameter_matched[index] {
if !parameter.is_variadic() && !parameter.is_keyword_variadic() {
self.errors.push(BindingError::ParameterAlreadyAssigned {
argument_index: get_argument_index(argument_index, num_synthetic_args),
parameter: ParameterContext::new(parameter, index, positional),
});
}
}
argument_parameters[argument_index] = Some(index);
parameter_matched[index] = true;
}
if let Some(first_excess_argument_index) = first_excess_positional {
self.errors.push(BindingError::TooManyPositionalArguments {
first_excess_argument_index: get_argument_index(
first_excess_argument_index,
num_synthetic_args,
),
expected_positional_count: parameters.positional().count(),
provided_positional_count: next_positional,
});
}
let mut missing = vec![];
for (index, matched) in parameter_matched.iter().copied().enumerate() {
if !matched {
let param = &parameters[index];
if param.is_variadic()
|| param.is_keyword_variadic()
|| param.default_type().is_some()
{
// variadic/keywords and defaulted arguments are not required
continue;
}
missing.push(ParameterContext::new(param, index, false));
}
}
if !missing.is_empty() {
self.errors.push(BindingError::MissingArguments {
parameters: ParameterContexts(missing),
});
}
self.return_ty = self.signature.return_ty.unwrap_or(Type::unknown());
self.argument_parameters = argument_parameters.into_boxed_slice();
self.parameter_tys = vec![None; parameters.len()].into_boxed_slice();
self.argument_parameters = matcher.finish();
}
fn check_types(
@ -1934,106 +2177,22 @@ impl<'db> Binding<'db> {
arguments: &CallArguments<'_>,
argument_types: &[Type<'db>],
) {
let mut num_synthetic_args = 0;
let get_argument_index = |argument_index: usize, num_synthetic_args: usize| {
if argument_index >= num_synthetic_args {
// Adjust the argument index to skip synthetic args, which don't appear at the call
// site and thus won't be in the Call node arguments list.
Some(argument_index - num_synthetic_args)
} else {
// we are erroring on a synthetic argument, we'll just emit the diagnostic on the
// entire Call node, since there's no argument node for this argument at the call site
None
}
};
let enumerate_argument_types = || {
arguments
.iter()
.zip(argument_types.iter().copied())
.enumerate()
};
let mut checker = ArgumentTypeChecker::new(
db,
&self.signature,
arguments,
argument_types,
&self.argument_parameters,
&mut self.parameter_tys,
&mut self.errors,
);
// If this overload is generic, first see if we can infer a specialization of the function
// from the arguments that were passed in.
let signature = &self.signature;
let parameters = signature.parameters();
if signature.generic_context.is_some() || signature.inherited_generic_context.is_some() {
let mut builder = SpecializationBuilder::new(db);
for (argument_index, (argument, argument_type)) in enumerate_argument_types() {
if matches!(argument, Argument::Synthetic) {
num_synthetic_args += 1;
}
let Some(parameter_index) = self.argument_parameters[argument_index] else {
// There was an error with argument when matching parameters, so don't bother
// type-checking it.
continue;
};
let parameter = &parameters[parameter_index];
let Some(expected_type) = parameter.annotated_type() else {
continue;
};
if let Err(error) = builder.infer(expected_type, argument_type) {
self.errors.push(BindingError::SpecializationError {
error,
argument_index: get_argument_index(argument_index, num_synthetic_args),
});
}
}
self.specialization = signature.generic_context.map(|gc| builder.build(gc));
self.inherited_specialization = signature.inherited_generic_context.map(|gc| {
// The inherited generic context is used when inferring the specialization of a
// generic class from a constructor call. In this case (only), we promote any
// typevars that are inferred as a literal to the corresponding instance type.
builder
.build(gc)
.apply_type_mapping(db, &TypeMapping::PromoteLiterals)
});
}
num_synthetic_args = 0;
for (argument_index, (argument, mut argument_type)) in enumerate_argument_types() {
if matches!(argument, Argument::Synthetic) {
num_synthetic_args += 1;
}
let Some(parameter_index) = self.argument_parameters[argument_index] else {
// There was an error with argument when matching parameters, so don't bother
// type-checking it.
continue;
};
let parameter = &parameters[parameter_index];
if let Some(mut expected_ty) = parameter.annotated_type() {
if let Some(specialization) = self.specialization {
argument_type = argument_type.apply_specialization(db, specialization);
expected_ty = expected_ty.apply_specialization(db, specialization);
}
if let Some(inherited_specialization) = self.inherited_specialization {
argument_type =
argument_type.apply_specialization(db, inherited_specialization);
expected_ty = expected_ty.apply_specialization(db, inherited_specialization);
}
if !argument_type.is_assignable_to(db, expected_ty) {
let positional = matches!(argument, Argument::Positional | Argument::Synthetic)
&& !parameter.is_variadic();
self.errors.push(BindingError::InvalidArgumentType {
parameter: ParameterContext::new(parameter, parameter_index, positional),
argument_index: get_argument_index(argument_index, num_synthetic_args),
expected_ty,
provided_ty: argument_type,
});
}
}
// We still update the actual type of the parameter in this binding to match the
// argument, even if the argument type is not assignable to the expected parameter
// type.
if let Some(existing) = self.parameter_tys[parameter_index].replace(argument_type) {
// We already verified in `match_parameters` that we only match multiple arguments
// with variadic parameters.
let union = UnionType::from_elements(db, [existing, argument_type]);
self.parameter_tys[parameter_index] = Some(union);
}
}
checker.infer_specialization();
checker.check_argument_types();
(self.specialization, self.inherited_specialization) = checker.finish();
if let Some(specialization) = self.specialization {
self.return_ty = self.return_ty.apply_specialization(db, specialization);
}