diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index fcb23dc077..bfb516f026 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -139,8 +139,7 @@ reveal_type(f(A())) # revealed: A reveal_type(f(*(A(),))) # revealed: A reveal_type(f(B())) # revealed: A -# TODO: revealed: A -reveal_type(f(*(B(),))) # revealed: Unknown +reveal_type(f(*(B(),))) # revealed: A # But, in this case, the arity check filters out the first overload, so we only have one match: reveal_type(f(B(), 1)) # revealed: B @@ -551,16 +550,13 @@ from overloaded import MyEnumSubclass, ActualEnum, f def _(actual_enum: ActualEnum, my_enum_instance: MyEnumSubclass): reveal_type(f(actual_enum)) # revealed: Both - # TODO: revealed: Both - reveal_type(f(*(actual_enum,))) # revealed: Unknown + reveal_type(f(*(actual_enum,))) # revealed: Both reveal_type(f(ActualEnum.A)) # revealed: OnlyA - # TODO: revealed: OnlyA - reveal_type(f(*(ActualEnum.A,))) # revealed: Unknown + reveal_type(f(*(ActualEnum.A,))) # revealed: OnlyA reveal_type(f(ActualEnum.B)) # revealed: OnlyB - # TODO: revealed: OnlyB - reveal_type(f(*(ActualEnum.B,))) # revealed: Unknown + reveal_type(f(*(ActualEnum.B,))) # revealed: OnlyB reveal_type(f(my_enum_instance)) # revealed: MyEnumSubclass reveal_type(f(*(my_enum_instance,))) # revealed: MyEnumSubclass @@ -1097,12 +1093,10 @@ reveal_type(f(*(1,))) # revealed: str def _(list_int: list[int], list_any: list[Any]): reveal_type(f(list_int)) # revealed: int - # TODO: revealed: int - reveal_type(f(*(list_int,))) # revealed: Unknown + reveal_type(f(*(list_int,))) # revealed: int reveal_type(f(list_any)) # revealed: int - # TODO: revealed: int - reveal_type(f(*(list_any,))) # revealed: Unknown + reveal_type(f(*(list_any,))) # revealed: int ``` ### Single list argument (ambiguous) @@ -1136,8 +1130,7 @@ def _(list_int: list[int], list_any: list[Any]): # All materializations of `list[int]` are assignable to `list[int]`, so it matches the first # overload. reveal_type(f(list_int)) # revealed: int - # TODO: revealed: int - reveal_type(f(*(list_int,))) # revealed: Unknown + reveal_type(f(*(list_int,))) # revealed: int # All materializations of `list[Any]` are assignable to `list[int]` and `list[Any]`, but the # return type of first and second overloads are not equivalent, so the overload matching @@ -1170,25 +1163,21 @@ reveal_type(f("a")) # revealed: str reveal_type(f(*("a",))) # revealed: str reveal_type(f((1, "b"))) # revealed: int -# TODO: revealed: int -reveal_type(f(*((1, "b"),))) # revealed: Unknown +reveal_type(f(*((1, "b"),))) # revealed: int reveal_type(f((1, 2))) # revealed: int -# TODO: revealed: int -reveal_type(f(*((1, 2),))) # revealed: Unknown +reveal_type(f(*((1, 2),))) # revealed: int def _(int_str: tuple[int, str], int_any: tuple[int, Any], any_any: tuple[Any, Any]): # All materializations are assignable to first overload, so second and third overloads are # eliminated reveal_type(f(int_str)) # revealed: int - # TODO: revealed: int - reveal_type(f(*(int_str,))) # revealed: Unknown + reveal_type(f(*(int_str,))) # revealed: int # All materializations are assignable to second overload, so the third overload is eliminated; # the return type of first and second overload is equivalent reveal_type(f(int_any)) # revealed: int - # TODO: revealed: int - reveal_type(f(*(int_any,))) # revealed: Unknown + reveal_type(f(*(int_any,))) # revealed: int # All materializations of `tuple[Any, Any]` are assignable to the parameters of all the # overloads, but the return types aren't equivalent, so the overload matching is ambiguous @@ -1266,26 +1255,22 @@ def _(list_int: list[int], list_any: list[Any], int_str: tuple[int, str], int_an # All materializations of both argument types are assignable to the first overload, so the # second and third overloads are filtered out reveal_type(f(list_int, int_str)) # revealed: A - # TODO: revealed: A - reveal_type(f(*(list_int, int_str))) # revealed: Unknown + reveal_type(f(*(list_int, int_str))) # revealed: A # All materialization of first argument is assignable to first overload and for the second # argument, they're assignable to the second overload, so the third overload is filtered out reveal_type(f(list_int, int_any)) # revealed: A - # TODO: revealed: A - reveal_type(f(*(list_int, int_any))) # revealed: Unknown + reveal_type(f(*(list_int, int_any))) # revealed: A # All materialization of first argument is assignable to second overload and for the second # argument, they're assignable to the first overload, so the third overload is filtered out reveal_type(f(list_any, int_str)) # revealed: A - # TODO: revealed: A - reveal_type(f(*(list_any, int_str))) # revealed: Unknown + reveal_type(f(*(list_any, int_str))) # revealed: A # All materializations of both arguments are assignable to the second overload, so the third # overload is filtered out reveal_type(f(list_any, int_any)) # revealed: A - # TODO: revealed: A - reveal_type(f(*(list_any, int_any))) # revealed: Unknown + reveal_type(f(*(list_any, int_any))) # revealed: A # All materializations of first argument is assignable to the second overload and for the second # argument, they're assignable to the third overload, so no overloads are filtered out; the @@ -1316,8 +1301,7 @@ from overloaded import f def _(literal: LiteralString, string: str, any: Any): reveal_type(f(literal)) # revealed: LiteralString - # TODO: revealed: LiteralString - reveal_type(f(*(literal,))) # revealed: Unknown + reveal_type(f(*(literal,))) # revealed: LiteralString reveal_type(f(string)) # revealed: str reveal_type(f(*(string,))) # revealed: str @@ -1355,12 +1339,10 @@ from overloaded import f def _(list_int: list[int], list_str: list[str], list_any: list[Any], any: Any): reveal_type(f(list_int)) # revealed: A - # TODO: revealed: A - reveal_type(f(*(list_int,))) # revealed: Unknown + reveal_type(f(*(list_int,))) # revealed: A reveal_type(f(list_str)) # revealed: str - # TODO: Should be `str` - reveal_type(f(*(list_str,))) # revealed: Unknown + reveal_type(f(*(list_str,))) # revealed: str reveal_type(f(list_any)) # revealed: Unknown reveal_type(f(*(list_any,))) # revealed: Unknown @@ -1561,12 +1543,10 @@ def _(any: Any): reveal_type(f(*(any,), flag=False)) # revealed: str def _(args: tuple[Any, Literal[True]]): - # TODO: revealed: int - reveal_type(f(*args)) # revealed: Unknown + reveal_type(f(*args)) # revealed: int def _(args: tuple[Any, Literal[False]]): - # TODO: revealed: str - reveal_type(f(*args)) # revealed: Unknown + reveal_type(f(*args)) # revealed: str ``` ### Argument type expansion diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 59a38e303b..841159221b 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -32,7 +32,7 @@ use crate::types::tuple::{TupleLength, TupleType}; use crate::types::{ BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType, - TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType, + TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, todo_type, }; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; @@ -1588,6 +1588,14 @@ impl<'db> CallableBinding<'db> { arguments: &CallArguments<'_, 'db>, matching_overload_indexes: &[usize], ) { + // The maximum number of parameters across all the overloads that are being considered + // for filtering. + let max_parameter_count = matching_overload_indexes + .iter() + .map(|&index| self.overloads[index].signature.parameters().len()) + .max() + .unwrap_or(0); + // These are the parameter indexes that matches the arguments that participate in the // filtering process. // @@ -1595,41 +1603,67 @@ impl<'db> CallableBinding<'db> { // gradual equivalent to the parameter types at the same index for other overloads. let mut participating_parameter_indexes = HashSet::new(); - // These only contain the top materialized argument types for the corresponding - // participating parameter indexes. - let mut top_materialized_argument_types = vec![]; + // The parameter types at each index for the first overload containing a parameter at + // that index. + let mut first_parameter_types: Vec>> = vec![None; max_parameter_count]; - for (argument_index, argument_type) in arguments.iter_types().enumerate() { - let mut first_parameter_type: Option> = None; - let mut participating_parameter_index = None; - - 'overload: for overload_index in matching_overload_indexes { + for argument_index in 0..arguments.len() { + for overload_index in matching_overload_indexes { let overload = &self.overloads[*overload_index]; - for parameter_index in &overload.argument_matches[argument_index].parameters { + for ¶meter_index in &overload.argument_matches[argument_index].parameters { // TODO: For an unannotated `self` / `cls` parameter, the type should be // `typing.Self` / `type[typing.Self]` - let current_parameter_type = overload.signature.parameters()[*parameter_index] + let current_parameter_type = overload.signature.parameters()[parameter_index] .annotated_type() .unwrap_or(Type::unknown()); + let first_parameter_type = &mut first_parameter_types[parameter_index]; if let Some(first_parameter_type) = first_parameter_type { if !first_parameter_type.is_equivalent_to(db, current_parameter_type) { - participating_parameter_index = Some(*parameter_index); - break 'overload; + participating_parameter_indexes.insert(parameter_index); } } else { - first_parameter_type = Some(current_parameter_type); + *first_parameter_type = Some(current_parameter_type); } } } + } - if let Some(parameter_index) = participating_parameter_index { - participating_parameter_indexes.insert(parameter_index); - top_materialized_argument_types.push(argument_type.top_materialization(db)); + let mut union_argument_type_builders = std::iter::repeat_with(|| UnionBuilder::new(db)) + .take(max_parameter_count) + .collect::>(); + + for (argument_index, argument_type) in arguments.iter_types().enumerate() { + for overload_index in matching_overload_indexes { + let overload = &self.overloads[*overload_index]; + for (parameter_index, variadic_argument_type) in + overload.argument_matches[argument_index].iter() + { + if !participating_parameter_indexes.contains(¶meter_index) { + continue; + } + union_argument_type_builders[parameter_index].add_in_place( + variadic_argument_type + .unwrap_or(argument_type) + .top_materialization(db), + ); + } } } - let top_materialized_argument_type = - Type::heterogeneous_tuple(db, top_materialized_argument_types); + // These only contain the top materialized argument types for the corresponding + // participating parameter indexes. + let top_materialized_argument_type = Type::heterogeneous_tuple( + db, + union_argument_type_builders + .into_iter() + .filter_map(|builder| { + if builder.is_empty() { + None + } else { + Some(builder.build()) + } + }), + ); // A flag to indicate whether we've found the overload that makes the remaining overloads // unmatched for the given argument types. @@ -1640,15 +1674,22 @@ impl<'db> CallableBinding<'db> { self.overloads[*current_index].mark_as_unmatched_overload(); continue; } - let mut parameter_types = Vec::with_capacity(arguments.len()); + + let mut union_parameter_types = std::iter::repeat_with(|| UnionBuilder::new(db)) + .take(max_parameter_count) + .collect::>(); + + // The number of parameters that have been skipped because they don't participate in + // the filtering process. This is used to make sure the types are added to the + // corresponding parameter index in `union_parameter_types`. + let mut skipped_parameters = 0; + for argument_index in 0..arguments.len() { - // The parameter types at the current argument index. - let mut current_parameter_types = vec![]; for overload_index in &matching_overload_indexes[..=upto] { let overload = &self.overloads[*overload_index]; for parameter_index in &overload.argument_matches[argument_index].parameters { if !participating_parameter_indexes.contains(parameter_index) { - // This parameter doesn't participate in the filtering process. + skipped_parameters += 1; continue; } // TODO: For an unannotated `self` / `cls` parameter, the type should be @@ -1664,17 +1705,24 @@ impl<'db> CallableBinding<'db> { parameter_type = parameter_type.apply_specialization(db, inherited_specialization); } - current_parameter_types.push(parameter_type); + union_parameter_types[parameter_index.saturating_sub(skipped_parameters)] + .add_in_place(parameter_type); } } - if current_parameter_types.is_empty() { - continue; - } - parameter_types.push(UnionType::from_elements(db, current_parameter_types)); } - if top_materialized_argument_type - .is_assignable_to(db, Type::heterogeneous_tuple(db, parameter_types)) - { + + let parameter_types = Type::heterogeneous_tuple( + db, + union_parameter_types.into_iter().filter_map(|builder| { + if builder.is_empty() { + None + } else { + Some(builder.build()) + } + }), + ); + + if top_materialized_argument_type.is_assignable_to(db, parameter_types) { filter_remaining_overloads = true; } }