mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 13:24:57 +00:00
[ty] Add support for **kwargs
(#20430)
## Summary This PR adds support for unpacking `**kwargs` argument. This can be matched against any standard (positional or keyword), keyword-only, or keyword variadic parameter that haven't been matched yet. This PR also takes care of special casing `TypedDict` because the key names and the corresponding value type is known, so we can be more precise in our matching and type checking step. In the future, this special casing would be extended to include `ParamSpec` as well. Part of astral-sh/ty#247 ## Test Plan Add test cases for various scenarios.
This commit is contained in:
parent
6f2b60708e
commit
902b0b4ce9
5 changed files with 566 additions and 47 deletions
|
@ -43,14 +43,14 @@ impl<'a, 'db> CallArguments<'a, 'db> {
|
|||
pub(crate) fn from_arguments(
|
||||
db: &'db dyn Db,
|
||||
arguments: &'a ast::Arguments,
|
||||
mut infer_argument_type: impl FnMut(&ast::Expr, &ast::Expr) -> Type<'db>,
|
||||
mut infer_argument_type: impl FnMut(Option<&ast::Expr>, &ast::Expr) -> Type<'db>,
|
||||
) -> Self {
|
||||
arguments
|
||||
.arguments_source_order()
|
||||
.map(|arg_or_keyword| match arg_or_keyword {
|
||||
ast::ArgOrKeyword::Arg(arg) => match arg {
|
||||
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
|
||||
let ty = infer_argument_type(arg, value);
|
||||
let ty = infer_argument_type(Some(arg), value);
|
||||
let length = ty
|
||||
.try_iterate(db)
|
||||
.map(|tuple| tuple.len())
|
||||
|
@ -59,11 +59,12 @@ impl<'a, 'db> CallArguments<'a, 'db> {
|
|||
}
|
||||
_ => (Argument::Positional, None),
|
||||
},
|
||||
ast::ArgOrKeyword::Keyword(ast::Keyword { arg, .. }) => {
|
||||
ast::ArgOrKeyword::Keyword(ast::Keyword { arg, value, .. }) => {
|
||||
if let Some(arg) = arg {
|
||||
(Argument::Keyword(&arg.id), None)
|
||||
} else {
|
||||
(Argument::Keywords, None)
|
||||
let ty = infer_argument_type(None, value);
|
||||
(Argument::Keywords, Some(ty))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -8,6 +8,7 @@ use std::fmt;
|
|||
|
||||
use itertools::{Either, Itertools};
|
||||
use ruff_db::parsed::parsed_module;
|
||||
use ruff_python_ast::name::Name;
|
||||
use smallvec::{SmallVec, smallvec, smallvec_inline};
|
||||
|
||||
use super::{Argument, CallArguments, CallError, CallErrorKind, InferContext, Signature, Type};
|
||||
|
@ -26,12 +27,13 @@ use crate::types::function::{
|
|||
DataclassTransformerParams, FunctionDecorators, FunctionType, KnownFunction, OverloadLiteral,
|
||||
};
|
||||
use crate::types::generics::{Specialization, SpecializationBuilder, SpecializationError};
|
||||
use crate::types::signatures::{Parameter, ParameterForm, Parameters};
|
||||
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
|
||||
use crate::types::tuple::{TupleLength, TupleType};
|
||||
use crate::types::{
|
||||
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
|
||||
KnownClass, KnownInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
|
||||
TypeAliasType, TypeMapping, UnionType, WrapperDescriptorKind, enums, ide_support, todo_type,
|
||||
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
|
||||
TrackedConstraintSet, TypeAliasType, TypeMapping, UnionType, WrapperDescriptorKind, enums,
|
||||
ide_support, todo_type,
|
||||
};
|
||||
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
|
||||
use ruff_python_ast::{self as ast, PythonVersion};
|
||||
|
@ -2088,6 +2090,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
|
|||
&mut self,
|
||||
argument_index: usize,
|
||||
argument: Argument<'a>,
|
||||
argument_type: Option<Type<'db>>,
|
||||
name: &str,
|
||||
) -> Result<(), ()> {
|
||||
let Some((parameter_index, parameter)) = self
|
||||
|
@ -2104,7 +2107,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
|
|||
self.assign_argument(
|
||||
argument_index,
|
||||
argument,
|
||||
None,
|
||||
argument_type,
|
||||
parameter_index,
|
||||
parameter,
|
||||
false,
|
||||
|
@ -2147,6 +2150,60 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn match_keyword_variadic(
|
||||
&mut self,
|
||||
db: &'db dyn Db,
|
||||
argument_index: usize,
|
||||
argument_type: Option<Type<'db>>,
|
||||
) {
|
||||
if let Some(Type::TypedDict(typed_dict)) = argument_type {
|
||||
// Special case TypedDict because we know which keys are present.
|
||||
for (name, field) in typed_dict.items(db) {
|
||||
let _ = self.match_keyword(
|
||||
argument_index,
|
||||
Argument::Keywords,
|
||||
Some(field.declared_ty),
|
||||
name.as_str(),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let value_type = match argument_type.map(|ty| {
|
||||
ty.member_lookup_with_policy(
|
||||
db,
|
||||
Name::new_static("__getitem__"),
|
||||
MemberLookupPolicy::NO_INSTANCE_FALLBACK,
|
||||
)
|
||||
.place
|
||||
}) {
|
||||
Some(Place::Type(keys_method, Boundness::Bound)) => keys_method
|
||||
.try_call(db, &CallArguments::positional([Type::unknown()]))
|
||||
.ok()
|
||||
.map_or_else(Type::unknown, |bindings| bindings.return_type(db)),
|
||||
_ => Type::unknown(),
|
||||
};
|
||||
|
||||
for (parameter_index, parameter) in self.parameters.iter().enumerate() {
|
||||
if self.parameter_matched[parameter_index] && !parameter.is_keyword_variadic() {
|
||||
continue;
|
||||
}
|
||||
if matches!(
|
||||
parameter.kind(),
|
||||
ParameterKind::PositionalOnly { .. } | ParameterKind::Variadic { .. }
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
self.assign_argument(
|
||||
argument_index,
|
||||
Argument::Keywords,
|
||||
Some(value_type),
|
||||
parameter_index,
|
||||
parameter,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(self) -> Box<[MatchedArgument<'db>]> {
|
||||
if let Some(first_excess_argument_index) = self.first_excess_positional {
|
||||
self.errors.push(BindingError::TooManyPositionalArguments {
|
||||
|
@ -2335,47 +2392,159 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
for (argument_index, adjusted_argument_index, argument, argument_type) in
|
||||
self.enumerate_argument_types()
|
||||
{
|
||||
// If the argument isn't splatted, just check its type directly.
|
||||
let Argument::Variadic(_) = argument else {
|
||||
for parameter_index in &self.argument_matches[argument_index].parameters {
|
||||
self.check_argument_type(
|
||||
adjusted_argument_index,
|
||||
argument,
|
||||
argument_type,
|
||||
*parameter_index,
|
||||
);
|
||||
match argument {
|
||||
Argument::Variadic(_) => self.check_variadic_argument_type(
|
||||
argument_index,
|
||||
adjusted_argument_index,
|
||||
argument,
|
||||
argument_type,
|
||||
),
|
||||
Argument::Keywords => self.check_keyword_variadic_argument_type(
|
||||
argument_index,
|
||||
adjusted_argument_index,
|
||||
argument,
|
||||
argument_type,
|
||||
),
|
||||
_ => {
|
||||
// If the argument isn't splatted, just check its type directly.
|
||||
for parameter_index in &self.argument_matches[argument_index].parameters {
|
||||
self.check_argument_type(
|
||||
adjusted_argument_index,
|
||||
argument,
|
||||
argument_type,
|
||||
*parameter_index,
|
||||
);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the argument is splatted, convert its type into a tuple describing the splatted
|
||||
// elements. For tuples, we don't have to do anything! For other types, we treat it as
|
||||
// an iterator, and create a homogeneous tuple of its output type, since we don't know
|
||||
// how many elements the iterator will produce.
|
||||
let argument_types = argument_type.iterate(self.db);
|
||||
fn check_variadic_argument_type(
|
||||
&mut self,
|
||||
argument_index: usize,
|
||||
adjusted_argument_index: Option<usize>,
|
||||
argument: Argument<'a>,
|
||||
argument_type: Type<'db>,
|
||||
) {
|
||||
// If the argument is splatted, convert its type into a tuple describing the splatted
|
||||
// elements. For tuples, we don't have to do anything! For other types, we treat it as
|
||||
// an iterator, and create a homogeneous tuple of its output type, since we don't know
|
||||
// how many elements the iterator will produce.
|
||||
let argument_types = argument_type.iterate(self.db);
|
||||
|
||||
// Resize the tuple of argument types to line up with the number of parameters this
|
||||
// argument was matched against. If parameter matching succeeded, then we can (TODO:
|
||||
// should be able to, see above) guarantee that all of the required elements of the
|
||||
// splatted tuple will have been matched with a parameter. But if parameter matching
|
||||
// failed, there might be more required elements. That means we can't use
|
||||
// TupleLength::Fixed below, because we would otherwise get a "too many values" error
|
||||
// when parameter matching failed.
|
||||
let desired_size =
|
||||
TupleLength::Variable(self.argument_matches[argument_index].parameters.len(), 0);
|
||||
let argument_types = argument_types
|
||||
.resize(self.db, desired_size)
|
||||
.expect("argument type should be consistent with its arity");
|
||||
// Resize the tuple of argument types to line up with the number of parameters this
|
||||
// argument was matched against. If parameter matching succeeded, then we can (TODO:
|
||||
// should be able to, see above) guarantee that all of the required elements of the
|
||||
// splatted tuple will have been matched with a parameter. But if parameter matching
|
||||
// failed, there might be more required elements. That means we can't use
|
||||
// TupleLength::Fixed below, because we would otherwise get a "too many values" error
|
||||
// when parameter matching failed.
|
||||
let desired_size =
|
||||
TupleLength::Variable(self.argument_matches[argument_index].parameters.len(), 0);
|
||||
let argument_types = argument_types
|
||||
.resize(self.db, desired_size)
|
||||
.expect("argument type should be consistent with its arity");
|
||||
|
||||
// Check the types by zipping through the splatted argument types and their matched
|
||||
// parameters.
|
||||
for (argument_type, parameter_index) in (argument_types.all_elements())
|
||||
// Check the types by zipping through the splatted argument types and their matched
|
||||
// parameters.
|
||||
for (argument_type, parameter_index) in
|
||||
(argument_types.all_elements()).zip(&self.argument_matches[argument_index].parameters)
|
||||
{
|
||||
self.check_argument_type(
|
||||
adjusted_argument_index,
|
||||
argument,
|
||||
*argument_type,
|
||||
*parameter_index,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn check_keyword_variadic_argument_type(
|
||||
&mut self,
|
||||
argument_index: usize,
|
||||
adjusted_argument_index: Option<usize>,
|
||||
argument: Argument<'a>,
|
||||
argument_type: Type<'db>,
|
||||
) {
|
||||
if let Type::TypedDict(typed_dict) = argument_type {
|
||||
for (argument_type, parameter_index) in typed_dict
|
||||
.items(self.db)
|
||||
.iter()
|
||||
.map(|(_, field)| field.declared_ty)
|
||||
.zip(&self.argument_matches[argument_index].parameters)
|
||||
{
|
||||
self.check_argument_type(
|
||||
adjusted_argument_index,
|
||||
argument,
|
||||
*argument_type,
|
||||
argument_type,
|
||||
*parameter_index,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// TODO: Instead of calling the `keys` and `__getitem__` methods, we should instead
|
||||
// get the constraints which satisfies the `SupportsKeysAndGetItem` protocol i.e., the
|
||||
// key and value type.
|
||||
let key_type = match argument_type
|
||||
.member_lookup_with_policy(
|
||||
self.db,
|
||||
Name::new_static("keys"),
|
||||
MemberLookupPolicy::NO_INSTANCE_FALLBACK,
|
||||
)
|
||||
.place
|
||||
{
|
||||
Place::Type(keys_method, Boundness::Bound) => keys_method
|
||||
.try_call(self.db, &CallArguments::none())
|
||||
.ok()
|
||||
.and_then(|bindings| {
|
||||
Some(
|
||||
bindings
|
||||
.return_type(self.db)
|
||||
.try_iterate(self.db)
|
||||
.ok()?
|
||||
.homogeneous_element_type(self.db),
|
||||
)
|
||||
}),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let Some(key_type) = key_type else {
|
||||
self.errors.push(BindingError::KeywordsNotAMapping {
|
||||
argument_index: adjusted_argument_index,
|
||||
provided_ty: argument_type,
|
||||
});
|
||||
return;
|
||||
};
|
||||
|
||||
if !key_type.is_assignable_to(self.db, KnownClass::Str.to_instance(self.db)) {
|
||||
self.errors.push(BindingError::InvalidKeyType {
|
||||
argument_index: adjusted_argument_index,
|
||||
provided_ty: key_type,
|
||||
});
|
||||
}
|
||||
|
||||
let value_type = match argument_type
|
||||
.member_lookup_with_policy(
|
||||
self.db,
|
||||
Name::new_static("__getitem__"),
|
||||
MemberLookupPolicy::NO_INSTANCE_FALLBACK,
|
||||
)
|
||||
.place
|
||||
{
|
||||
Place::Type(keys_method, Boundness::Bound) => keys_method
|
||||
.try_call(self.db, &CallArguments::positional([Type::unknown()]))
|
||||
.ok()
|
||||
.map_or_else(Type::unknown, |bindings| bindings.return_type(self.db)),
|
||||
_ => Type::unknown(),
|
||||
};
|
||||
|
||||
for (argument_type, parameter_index) in
|
||||
std::iter::repeat(value_type).zip(&self.argument_matches[argument_index].parameters)
|
||||
{
|
||||
self.check_argument_type(
|
||||
adjusted_argument_index,
|
||||
Argument::Keywords,
|
||||
argument_type,
|
||||
*parameter_index,
|
||||
);
|
||||
}
|
||||
|
@ -2493,24 +2662,27 @@ impl<'db> Binding<'db> {
|
|||
let parameters = self.signature.parameters();
|
||||
let mut matcher =
|
||||
ArgumentMatcher::new(arguments, parameters, argument_forms, &mut self.errors);
|
||||
let mut keywords_arguments = vec![];
|
||||
for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() {
|
||||
match argument {
|
||||
Argument::Positional | Argument::Synthetic => {
|
||||
let _ = matcher.match_positional(argument_index, argument, None);
|
||||
}
|
||||
Argument::Keyword(name) => {
|
||||
let _ = matcher.match_keyword(argument_index, argument, name);
|
||||
let _ = matcher.match_keyword(argument_index, argument, None, name);
|
||||
}
|
||||
Argument::Variadic(length) => {
|
||||
let _ =
|
||||
matcher.match_variadic(db, argument_index, argument, argument_type, length);
|
||||
}
|
||||
Argument::Keywords => {
|
||||
// TODO
|
||||
continue;
|
||||
keywords_arguments.push((argument_index, argument_type));
|
||||
}
|
||||
}
|
||||
}
|
||||
for (keywords_index, keywords_type) in keywords_arguments {
|
||||
matcher.match_keyword_variadic(db, keywords_index, keywords_type);
|
||||
}
|
||||
self.return_ty = self.signature.return_ty.unwrap_or(Type::unknown());
|
||||
self.parameter_tys = vec![None; parameters.len()].into_boxed_slice();
|
||||
self.argument_matches = matcher.finish();
|
||||
|
@ -2874,6 +3046,15 @@ pub(crate) enum BindingError<'db> {
|
|||
expected_ty: Type<'db>,
|
||||
provided_ty: Type<'db>,
|
||||
},
|
||||
/// The type of the keyword-variadic argument's key is not `str`.
|
||||
InvalidKeyType {
|
||||
argument_index: Option<usize>,
|
||||
provided_ty: Type<'db>,
|
||||
},
|
||||
KeywordsNotAMapping {
|
||||
argument_index: Option<usize>,
|
||||
provided_ty: Type<'db>,
|
||||
},
|
||||
/// One or more required parameters (that is, with no default) is not supplied by any argument.
|
||||
MissingArguments {
|
||||
parameters: ParameterContexts,
|
||||
|
@ -3013,6 +3194,45 @@ impl<'db> BindingError<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
Self::InvalidKeyType {
|
||||
argument_index,
|
||||
provided_ty,
|
||||
} => {
|
||||
let range = Self::get_node(node, *argument_index);
|
||||
let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, range) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let provided_ty_display = provided_ty.display(context.db());
|
||||
let mut diag = builder.into_diagnostic(
|
||||
"Argument expression after ** must be a mapping with `str` key type",
|
||||
);
|
||||
diag.set_primary_message(format_args!("Found `{provided_ty_display}`"));
|
||||
|
||||
if let Some(union_diag) = union_diag {
|
||||
union_diag.add_union_context(context.db(), &mut diag);
|
||||
}
|
||||
}
|
||||
|
||||
Self::KeywordsNotAMapping {
|
||||
argument_index,
|
||||
provided_ty,
|
||||
} => {
|
||||
let range = Self::get_node(node, *argument_index);
|
||||
let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, range) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let provided_ty_display = provided_ty.display(context.db());
|
||||
let mut diag =
|
||||
builder.into_diagnostic("Argument expression after ** must be a mapping type");
|
||||
diag.set_primary_message(format_args!("Found `{provided_ty_display}`"));
|
||||
|
||||
if let Some(union_diag) = union_diag {
|
||||
union_diag.add_union_context(context.db(), &mut diag);
|
||||
}
|
||||
}
|
||||
|
||||
Self::TooManyPositionalArguments {
|
||||
first_excess_argument_index,
|
||||
expected_positional_count,
|
||||
|
|
|
@ -1731,7 +1731,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
let mut call_arguments =
|
||||
CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| {
|
||||
let ty = self.infer_expression(splatted_value, TypeContext::default());
|
||||
self.store_expression_type(argument, ty);
|
||||
if let Some(argument) = argument {
|
||||
self.store_expression_type(argument, ty);
|
||||
}
|
||||
ty
|
||||
});
|
||||
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
|
||||
|
@ -4944,7 +4946,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
for (((_, argument_type), form), arg_or_keyword) in iter {
|
||||
let argument = match arg_or_keyword {
|
||||
// We already inferred the type of splatted arguments.
|
||||
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_)) => continue,
|
||||
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
|
||||
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue,
|
||||
ast::ArgOrKeyword::Arg(arg) => arg,
|
||||
ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value,
|
||||
};
|
||||
|
@ -5787,7 +5790,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
let mut call_arguments =
|
||||
CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| {
|
||||
let ty = self.infer_expression(splatted_value, TypeContext::default());
|
||||
self.store_expression_type(argument, ty);
|
||||
if let Some(argument) = argument {
|
||||
self.store_expression_type(argument, ty);
|
||||
}
|
||||
ty
|
||||
});
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue