diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index d6fbf08b3a..c2cc47d1ee 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -660,9 +660,9 @@ class Foo: ```py from overloaded import A, B, C, Foo, f -from typing_extensions import reveal_type +from typing_extensions import Any, reveal_type -def _(ab: A | B, a=1): +def _(ab: A | B, a: int | Any): reveal_type(f(a1=a, a2=a, a3=a)) # revealed: C reveal_type(f(A(), a1=a, a2=a, a3=a)) # revealed: A reveal_type(f(B(), a1=a, a2=a, a3=a)) # revealed: B @@ -750,7 +750,7 @@ def _(ab: A | B, a=1): ) ) -def _(foo: Foo, ab: A | B, a=1): +def _(foo: Foo, ab: A | B, a: int | Any): reveal_type(foo.f(a1=a, a2=a, a3=a)) # revealed: C reveal_type(foo.f(A(), a1=a, a2=a, a3=a)) # revealed: A reveal_type(foo.f(B(), a1=a, a2=a, a3=a)) # revealed: B @@ -831,6 +831,76 @@ def _(foo: Foo, ab: A | B, a=1): ) ``` +### Optimization: Limit expansion size + + + +To prevent combinatorial explosion, ty limits the number of argument lists created by expanding a +single argument. + +`overloaded.pyi`: + +```pyi +from typing import overload + +class A: ... +class B: ... +class C: ... + +@overload +def f() -> None: ... +@overload +def f(**kwargs: int) -> C: ... +@overload +def f(x: A, /, **kwargs: int) -> A: ... +@overload +def f(x: B, /, **kwargs: int) -> B: ... +``` + +```py +from overloaded import A, B, f +from typing_extensions import reveal_type + +def _(a: int | None): + reveal_type( + # error: [no-matching-overload] + # revealed: Unknown + f( + A(), + a1=a, + a2=a, + a3=a, + a4=a, + a5=a, + a6=a, + a7=a, + a8=a, + a9=a, + a10=a, + a11=a, + a12=a, + a13=a, + a14=a, + a15=a, + a16=a, + a17=a, + a18=a, + a19=a, + a20=a, + a21=a, + a22=a, + a23=a, + a24=a, + a25=a, + a26=a, + a27=a, + a28=a, + a29=a, + a30=a, + ) + ) +``` + ## Filtering based on `Any` / `Unknown` This is the step 5 of the overload call evaluation algorithm which specifies that: diff --git a/crates/ty_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Argument_type_expans…_-_Optimization___Limit_…_(cd61048adbc17331).snap b/crates/ty_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Argument_type_expans…_-_Optimization___Limit_…_(cd61048adbc17331).snap new file mode 100644 index 0000000000..cf278f4328 --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Argument_type_expans…_-_Optimization___Limit_…_(cd61048adbc17331).snap @@ -0,0 +1,183 @@ +--- +source: crates/ty_test/src/lib.rs +expression: snapshot +--- +--- +mdtest name: overloads.md - Overloads - Argument type expansion - Optimization: Limit expansion size +mdtest path: crates/ty_python_semantic/resources/mdtest/call/overloads.md +--- + +# Python source files + +## overloaded.pyi + +``` + 1 | from typing import overload + 2 | + 3 | class A: ... + 4 | class B: ... + 5 | class C: ... + 6 | + 7 | @overload + 8 | def f() -> None: ... + 9 | @overload +10 | def f(**kwargs: int) -> C: ... +11 | @overload +12 | def f(x: A, /, **kwargs: int) -> A: ... +13 | @overload +14 | def f(x: B, /, **kwargs: int) -> B: ... +``` + +## mdtest_snippet.py + +``` + 1 | from overloaded import A, B, f + 2 | from typing_extensions import reveal_type + 3 | + 4 | def _(a: int | None): + 5 | reveal_type( + 6 | # error: [no-matching-overload] + 7 | # revealed: Unknown + 8 | f( + 9 | A(), +10 | a1=a, +11 | a2=a, +12 | a3=a, +13 | a4=a, +14 | a5=a, +15 | a6=a, +16 | a7=a, +17 | a8=a, +18 | a9=a, +19 | a10=a, +20 | a11=a, +21 | a12=a, +22 | a13=a, +23 | a14=a, +24 | a15=a, +25 | a16=a, +26 | a17=a, +27 | a18=a, +28 | a19=a, +29 | a20=a, +30 | a21=a, +31 | a22=a, +32 | a23=a, +33 | a24=a, +34 | a25=a, +35 | a26=a, +36 | a27=a, +37 | a28=a, +38 | a29=a, +39 | a30=a, +40 | ) +41 | ) +``` + +# Diagnostics + +``` +error[no-matching-overload]: No overload of function `f` matches arguments + --> src/mdtest_snippet.py:8:9 + | + 6 | # error: [no-matching-overload] + 7 | # revealed: Unknown + 8 | / f( + 9 | | A(), +10 | | a1=a, +11 | | a2=a, +12 | | a3=a, +13 | | a4=a, +14 | | a5=a, +15 | | a6=a, +16 | | a7=a, +17 | | a8=a, +18 | | a9=a, +19 | | a10=a, +20 | | a11=a, +21 | | a12=a, +22 | | a13=a, +23 | | a14=a, +24 | | a15=a, +25 | | a16=a, +26 | | a17=a, +27 | | a18=a, +28 | | a19=a, +29 | | a20=a, +30 | | a21=a, +31 | | a22=a, +32 | | a23=a, +33 | | a24=a, +34 | | a25=a, +35 | | a26=a, +36 | | a27=a, +37 | | a28=a, +38 | | a29=a, +39 | | a30=a, +40 | | ) + | |_________^ +41 | ) + | +info: Limit of argument type expansion reached at argument 10 +info: First overload defined here + --> src/overloaded.pyi:8:5 + | + 7 | @overload + 8 | def f() -> None: ... + | ^^^^^^^^^^^ + 9 | @overload +10 | def f(**kwargs: int) -> C: ... + | +info: Possible overloads for function `f`: +info: () -> None +info: (**kwargs: int) -> C +info: (x: A, /, **kwargs: int) -> A +info: (x: B, /, **kwargs: int) -> B +info: rule `no-matching-overload` is enabled by default + +``` + +``` +info[revealed-type]: Revealed type + --> src/mdtest_snippet.py:8:9 + | + 6 | # error: [no-matching-overload] + 7 | # revealed: Unknown + 8 | / f( + 9 | | A(), +10 | | a1=a, +11 | | a2=a, +12 | | a3=a, +13 | | a4=a, +14 | | a5=a, +15 | | a6=a, +16 | | a7=a, +17 | | a8=a, +18 | | a9=a, +19 | | a10=a, +20 | | a11=a, +21 | | a12=a, +22 | | a13=a, +23 | | a14=a, +24 | | a15=a, +25 | | a16=a, +26 | | a17=a, +27 | | a18=a, +28 | | a19=a, +29 | | a20=a, +30 | | a21=a, +31 | | a22=a, +32 | | a23=a, +33 | | a24=a, +34 | | a25=a, +35 | | a26=a, +36 | | a27=a, +37 | | a28=a, +38 | | a29=a, +39 | | a30=a, +40 | | ) + | |_________^ `Unknown` +41 | ) + | + +``` diff --git a/crates/ty_python_semantic/src/types/call/arguments.rs b/crates/ty_python_semantic/src/types/call/arguments.rs index 648b0cdab9..3bf78c6a75 100644 --- a/crates/ty_python_semantic/src/types/call/arguments.rs +++ b/crates/ty_python_semantic/src/types/call/arguments.rs @@ -127,31 +127,39 @@ impl<'a, 'db> CallArguments<'a, 'db> { /// contains the same arguments, but with one or more of the argument types expanded. /// /// [argument type expansion]: https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion - pub(crate) fn expand( - &self, - db: &'db dyn Db, - ) -> impl Iterator>> + '_ { + pub(super) fn expand(&self, db: &'db dyn Db) -> impl Iterator> + '_ { + /// Maximum number of argument lists that can be generated in a single expansion step. + static MAX_EXPANSIONS: usize = 512; + /// Represents the state of the expansion process. + enum State<'a, 'b, 'db> { + LimitReached(usize), + Expanding(ExpandingState<'a, 'b, 'db>), + } + + /// Represents the expanding state with either the initial types or the expanded types. /// /// This is useful to avoid cloning the initial types vector if none of the types can be /// expanded. - enum State<'a, 'b, 'db> { + enum ExpandingState<'a, 'b, 'db> { Initial(&'b Vec>>), Expanded(Vec>), } - impl<'db> State<'_, '_, 'db> { + impl<'db> ExpandingState<'_, '_, 'db> { fn len(&self) -> usize { match self { - State::Initial(_) => 1, - State::Expanded(expanded) => expanded.len(), + ExpandingState::Initial(_) => 1, + ExpandingState::Expanded(expanded) => expanded.len(), } } fn iter(&self) -> impl Iterator>]> + '_ { match self { - State::Initial(types) => Either::Left(std::iter::once(types.as_slice())), - State::Expanded(expanded) => { + ExpandingState::Initial(types) => { + Either::Left(std::iter::once(types.as_slice())) + } + ExpandingState::Expanded(expanded) => { Either::Right(expanded.iter().map(CallArguments::types)) } } @@ -160,44 +168,82 @@ impl<'a, 'db> CallArguments<'a, 'db> { let mut index = 0; - std::iter::successors(Some(State::Initial(&self.types)), move |previous| { - // Find the next type that can be expanded. - let expanded_types = loop { - let arg_type = self.types.get(index)?; - if let Some(arg_type) = arg_type { - if let Some(expanded_types) = expand_type(db, *arg_type) { - break expanded_types; + std::iter::successors( + Some(State::Expanding(ExpandingState::Initial(&self.types))), + move |previous| { + let state = match previous { + State::LimitReached(index) => return Some(State::LimitReached(*index)), + State::Expanding(expanding_state) => expanding_state, + }; + + // Find the next type that can be expanded. + let expanded_types = loop { + let arg_type = self.types.get(index)?; + if let Some(arg_type) = arg_type { + if let Some(expanded_types) = expand_type(db, *arg_type) { + break expanded_types; + } + } + index += 1; + }; + + let expansion_size = expanded_types.len() * state.len(); + if expansion_size > MAX_EXPANSIONS { + tracing::debug!( + "Skipping argument type expansion as it would exceed the \ + maximum number of expansions ({MAX_EXPANSIONS})" + ); + return Some(State::LimitReached(index)); + } + + let mut expanded_arguments = Vec::with_capacity(expansion_size); + + for pre_expanded_types in state.iter() { + for subtype in &expanded_types { + let mut new_expanded_types = pre_expanded_types.to_vec(); + new_expanded_types[index] = Some(*subtype); + expanded_arguments.push(CallArguments::new( + self.arguments.clone(), + new_expanded_types, + )); } } + + // Increment the index to move to the next argument type for the next iteration. index += 1; - }; - let mut expanded_arguments = Vec::with_capacity(expanded_types.len() * previous.len()); - - for pre_expanded_types in previous.iter() { - for subtype in &expanded_types { - let mut new_expanded_types = pre_expanded_types.to_vec(); - new_expanded_types[index] = Some(*subtype); - expanded_arguments.push(CallArguments::new( - self.arguments.clone(), - new_expanded_types, - )); - } - } - - // Increment the index to move to the next argument type for the next iteration. - index += 1; - - Some(State::Expanded(expanded_arguments)) - }) + Some(State::Expanding(ExpandingState::Expanded( + expanded_arguments, + ))) + }, + ) .skip(1) // Skip the initial state, which has no expanded types. .map(|state| match state { - State::Initial(_) => unreachable!("initial state should be skipped"), - State::Expanded(expanded) => expanded, + State::LimitReached(index) => Expansion::LimitReached(index), + State::Expanding(ExpandingState::Initial(_)) => { + unreachable!("initial state should be skipped") + } + State::Expanding(ExpandingState::Expanded(expanded)) => Expansion::Expanded(expanded), }) } } +/// Represents a single element of the expansion process for argument types for [`expand`]. +/// +/// [`expand`]: CallArguments::expand +pub(super) enum Expansion<'a, 'db> { + /// Indicates that the expansion process has reached the maximum number of argument lists + /// that can be generated in a single step. + /// + /// The contained `usize` is the index of the argument type which would have been expanded + /// next, if not for the limit. + LimitReached(usize), + + /// Contains the expanded argument lists, where each list contains the same arguments, but with + /// one or more of the argument types expanded. + Expanded(Vec>), +} + impl<'a, 'db> FromIterator<(Argument<'a>, Option>)> for CallArguments<'a, 'db> { fn from_iter(iter: T) -> Self where diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index d2d46a2baa..da1d962f99 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -16,7 +16,7 @@ use crate::Program; use crate::db::Db; use crate::dunder_all::dunder_all_names; use crate::place::{Boundness, Place}; -use crate::types::call::arguments::is_expandable_type; +use crate::types::call::arguments::{Expansion, is_expandable_type}; use crate::types::diagnostic::{ CALL_NON_CALLABLE, CONFLICTING_ARGUMENT_FORMS, INVALID_ARGUMENT_TYPE, MISSING_ARGUMENT, NO_MATCHING_OVERLOAD, PARAMETER_ALREADY_ASSIGNED, TOO_MANY_POSITIONAL_ARGUMENTS, @@ -1368,7 +1368,18 @@ impl<'db> CallableBinding<'db> { } } - for expanded_argument_lists in expansions { + for expansion in expansions { + let expanded_argument_lists = match expansion { + Expansion::LimitReached(index) => { + snapshotter.restore(self, post_evaluation_snapshot); + self.overload_call_return_type = Some( + OverloadCallReturnType::ArgumentTypeExpansionLimitReached(index), + ); + return; + } + Expansion::Expanded(argument_lists) => argument_lists, + }; + // This is the merged state of the bindings after evaluating all of the expanded // argument lists. This will be the final state to restore the bindings to if all of // the expanded argument lists evaluated successfully. @@ -1667,7 +1678,8 @@ impl<'db> CallableBinding<'db> { if let Some(overload_call_return_type) = self.overload_call_return_type { return match overload_call_return_type { OverloadCallReturnType::ArgumentTypeExpansion(return_type) => return_type, - OverloadCallReturnType::Ambiguous => Type::unknown(), + OverloadCallReturnType::ArgumentTypeExpansionLimitReached(_) + | OverloadCallReturnType::Ambiguous => Type::unknown(), }; } if let Some((_, first_overload)) = self.matching_overloads().next() { @@ -1778,6 +1790,23 @@ impl<'db> CallableBinding<'db> { String::new() } )); + + if let Some(index) = + self.overload_call_return_type + .and_then( + |overload_call_return_type| match overload_call_return_type { + OverloadCallReturnType::ArgumentTypeExpansionLimitReached( + index, + ) => Some(index), + _ => None, + }, + ) + { + diag.info(format_args!( + "Limit of argument type expansion reached at argument {index}" + )); + } + if let Some((kind, function)) = function_type_and_kind { let (overloads, implementation) = function.overloads_and_implementation(context.db()); @@ -1844,6 +1873,7 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> { #[derive(Debug, Copy, Clone)] enum OverloadCallReturnType<'db> { ArgumentTypeExpansion(Type<'db>), + ArgumentTypeExpansionLimitReached(usize), Ambiguous, }