[ty] Fix panics when pulling types for various special forms that have the wrong number of parameters (#18642)

This commit is contained in:
Alex Waygood 2025-06-17 10:40:50 +01:00 committed by GitHub
parent 342b2665db
commit 1d458d4314
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 151 additions and 78 deletions

View file

@ -139,8 +139,6 @@ x: int = MagicMock()
## Invalid
<!-- pull-types:skip -->
`Any` cannot be parameterized:
```py

View file

@ -14,8 +14,6 @@ directly.
### Negation
<!-- pull-types:skip -->
```py
from typing import Literal
from ty_extensions import Not, static_assert
@ -25,8 +23,12 @@ def negate(n1: Not[int], n2: Not[Not[int]], n3: Not[Not[Not[int]]]) -> None:
reveal_type(n2) # revealed: int
reveal_type(n3) # revealed: ~int
# error: "Special form `ty_extensions.Not` expected exactly one type parameter"
# error: "Special form `ty_extensions.Not` expected exactly 1 type argument, got 2"
n: Not[int, str]
# error: [invalid-type-form] "Special form `ty_extensions.Not` expected exactly 1 type argument, got 0"
o: Not[()]
p: Not[(int,)]
def static_truthiness(not_one: Not[Literal[1]]) -> None:
# TODO: `bool` is not incorrect, but these would ideally be `Literal[True]` and `Literal[False]`
@ -373,8 +375,6 @@ static_assert(not is_single_valued(Literal["a"] | Literal["b"]))
## `TypeOf`
<!-- pull-types:skip -->
We use `TypeOf` to get the inferred type of an expression. This is useful when we want to refer to
it in a type expression. For example, if we want to make sure that the class literal type `str` is a
subtype of `type[str]`, we can not use `is_subtype_of(str, type[str])`, as that would test if the
@ -400,13 +400,13 @@ class Derived(Base): ...
```py
def type_of_annotation() -> None:
t1: TypeOf[Base] = Base
t2: TypeOf[Base] = Derived # error: [invalid-assignment]
t2: TypeOf[(Base,)] = Derived # error: [invalid-assignment]
# Note how this is different from `type[…]` which includes subclasses:
s1: type[Base] = Base
s2: type[Base] = Derived # no error here
# error: "Special form `ty_extensions.TypeOf` expected exactly one type parameter"
# error: "Special form `ty_extensions.TypeOf` expected exactly 1 type argument, got 3"
t: TypeOf[int, str, bytes]
# error: [invalid-type-form] "`ty_extensions.TypeOf` requires exactly one argument when used in a type expression"
@ -416,8 +416,6 @@ def f(x: TypeOf) -> None:
## `CallableTypeOf`
<!-- pull-types:skip -->
The `CallableTypeOf` special form can be used to extract the `Callable` structural type inhabited by
a given callable object. This can be used to get the externally visibly signature of the object,
which can then be used to test various type properties.
@ -436,15 +434,23 @@ def f2() -> int:
def f3(x: int, y: str) -> None:
return
# error: [invalid-type-form] "Special form `ty_extensions.CallableTypeOf` expected exactly one type parameter"
# error: [invalid-type-form] "Special form `ty_extensions.CallableTypeOf` expected exactly 1 type argument, got 2"
c1: CallableTypeOf[f1, f2]
# error: [invalid-type-form] "Expected the first argument to `ty_extensions.CallableTypeOf` to be a callable object, but got an object of type `Literal["foo"]`"
c2: CallableTypeOf["foo"]
# error: [invalid-type-form] "Expected the first argument to `ty_extensions.CallableTypeOf` to be a callable object, but got an object of type `Literal["foo"]`"
c20: CallableTypeOf[("foo",)]
# error: [invalid-type-form] "`ty_extensions.CallableTypeOf` requires exactly one argument when used in a type expression"
def f(x: CallableTypeOf) -> None:
reveal_type(x) # revealed: Unknown
c3: CallableTypeOf[(f3,)]
# error: [invalid-type-form] "Special form `ty_extensions.CallableTypeOf` expected exactly 1 type argument, got 0"
c4: CallableTypeOf[()]
```
Using it in annotation to reveal the signature of the callable object:

View file

@ -1888,6 +1888,26 @@ pub(crate) fn report_invalid_arguments_to_annotated(
));
}
pub(crate) fn report_invalid_argument_number_to_special_form(
context: &InferContext,
subscript: &ast::ExprSubscript,
special_form: SpecialFormType,
received_arguments: usize,
expected_arguments: u8,
) {
let noun = if expected_arguments == 1 {
"type argument"
} else {
"type arguments"
};
if let Some(builder) = context.report_lint(&INVALID_TYPE_FORM, subscript) {
builder.into_diagnostic(format_args!(
"Special form `{special_form}` expected exactly {expected_arguments} {noun}, \
got {received_arguments}",
));
}
}
pub(crate) fn report_bad_argument_to_get_protocol_members(
context: &InferContext,
call: &ast::ExprCall,

View file

@ -84,10 +84,10 @@ use crate::types::diagnostic::{
INVALID_TYPE_VARIABLE_CONSTRAINTS, POSSIBLY_UNBOUND_IMPLICIT_CALL, POSSIBLY_UNBOUND_IMPORT,
TypeCheckDiagnostics, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT,
UNRESOLVED_REFERENCE, UNSUPPORTED_OPERATOR, report_implicit_return_type,
report_invalid_arguments_to_annotated, report_invalid_arguments_to_callable,
report_invalid_assignment, report_invalid_attribute_assignment,
report_invalid_generator_function_return_type, report_invalid_return_type,
report_possibly_unbound_attribute,
report_invalid_argument_number_to_special_form, report_invalid_arguments_to_annotated,
report_invalid_arguments_to_callable, report_invalid_assignment,
report_invalid_attribute_assignment, report_invalid_generator_function_return_type,
report_invalid_return_type, report_possibly_unbound_attribute,
};
use crate::types::function::{
FunctionDecorators, FunctionLiteral, FunctionType, KnownFunction, OverloadLiteral,
@ -9329,6 +9329,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
Type::unknown()
}
Type::ClassLiteral(literal) if literal.is_known(self.db(), KnownClass::Any) => {
self.infer_expression(slice);
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
builder.into_diagnostic("Type `typing.Any` expected no type parameter");
}
@ -9558,20 +9559,33 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
// Type API special forms
SpecialFormType::Not => match arguments_slice {
ast::Expr::Tuple(_) => {
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
builder.into_diagnostic(format_args!(
"Special form `{special_form}` expected exactly one type parameter",
));
SpecialFormType::Not => {
let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice {
&*tuple.elts
} else {
std::slice::from_ref(arguments_slice)
};
let num_arguments = arguments.len();
let negated_type = if num_arguments == 1 {
self.infer_type_expression(&arguments[0]).negate(db)
} else {
for argument in arguments {
self.infer_type_expression(argument);
}
report_invalid_argument_number_to_special_form(
&self.context,
subscript,
special_form,
num_arguments,
1,
);
Type::unknown()
};
if arguments_slice.is_tuple_expr() {
self.store_expression_type(arguments_slice, negated_type);
}
_ => {
let argument_type = self.infer_type_expression(arguments_slice);
argument_type.negate(db)
negated_type
}
},
SpecialFormType::Intersection => {
let elements = match arguments_slice {
ast::Expr::Tuple(tuple) => Either::Left(tuple.iter()),
@ -9589,32 +9603,61 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
ty
}
SpecialFormType::TypeOf => match arguments_slice {
ast::Expr::Tuple(_) => {
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
builder.into_diagnostic(format_args!(
"Special form `{special_form}` expected exactly one type parameter",
));
SpecialFormType::TypeOf => {
let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice {
&*tuple.elts
} else {
std::slice::from_ref(arguments_slice)
};
let num_arguments = arguments.len();
let type_of_type = if num_arguments == 1 {
// N.B. This uses `infer_expression` rather than `infer_type_expression`
self.infer_expression(&arguments[0])
} else {
for argument in arguments {
self.infer_type_expression(argument);
}
report_invalid_argument_number_to_special_form(
&self.context,
subscript,
special_form,
num_arguments,
1,
);
Type::unknown()
};
if arguments_slice.is_tuple_expr() {
self.store_expression_type(arguments_slice, type_of_type);
}
type_of_type
}
_ => {
// NB: This calls `infer_expression` instead of `infer_type_expression`.
self.infer_expression(arguments_slice)
SpecialFormType::CallableTypeOf => {
let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice {
&*tuple.elts
} else {
std::slice::from_ref(arguments_slice)
};
let num_arguments = arguments.len();
if num_arguments != 1 {
for argument in arguments {
self.infer_expression(argument);
}
},
SpecialFormType::CallableTypeOf => match arguments_slice {
ast::Expr::Tuple(_) => {
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
builder.into_diagnostic(format_args!(
"Special form `{special_form}` expected exactly one type parameter",
));
report_invalid_argument_number_to_special_form(
&self.context,
subscript,
special_form,
num_arguments,
1,
);
if arguments_slice.is_tuple_expr() {
self.store_expression_type(arguments_slice, Type::unknown());
}
Type::unknown()
return Type::unknown();
}
_ => {
let argument_type = self.infer_expression(arguments_slice);
let argument_type = self.infer_expression(&arguments[0]);
let bindings = argument_type.bindings(db);
// SAFETY: This is enforced by the constructor methods on `Bindings` even in
@ -9644,15 +9687,21 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
actual_type = argument_type.display(db)
));
}
if arguments_slice.is_tuple_expr() {
self.store_expression_type(arguments_slice, Type::unknown());
}
return Type::unknown();
};
let signature = CallableSignature::from_overloads(
std::iter::once(signature).chain(signature_iter),
);
Type::Callable(CallableType::new(db, signature, false))
let callable_type_of = Type::Callable(CallableType::new(db, signature, false));
if arguments_slice.is_tuple_expr() {
self.store_expression_type(arguments_slice, callable_type_of);
}
callable_type_of
}
},
SpecialFormType::ChainMap => self.infer_parameterized_legacy_typing_alias(
subscript,