From a3f3d734a19eebc12d216b5fa16515468e46253d Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 19 Mar 2025 14:42:42 +0000 Subject: [PATCH] [red-knot] Ban list literals in most contexts in type expressions (#16847) ## Summary This PR reworks `TypeInferenceBuilder::infer_type_expression()` so that we emit diagnostics when encountering a list literal in a type expression. The only place where a list literal is allowed in a type expression is if it appears as the first argument to `Callable[]`, and `Callable` is already heavily special-cased in our type-expression parsing. In order to ensure that list literals are _always_ allowed as the _first_ argument to `Callabler` (but never allowed as the second, third, etc. argument), I had to do some refactoring of our type-expression parsing for `Callable` annotations. ## Test Plan New mdtests added, and existing ones updated --- .../resources/mdtest/annotations/callable.md | 56 +++++++++- .../resources/mdtest/annotations/invalid.md | 2 + .../src/types/infer.rs | 103 +++++++++--------- 3 files changed, 111 insertions(+), 50 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/annotations/callable.md b/crates/red_knot_python_semantic/resources/mdtest/annotations/callable.md index 37ea251e37..b84bfe41a5 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/annotations/callable.md +++ b/crates/red_knot_python_semantic/resources/mdtest/annotations/callable.md @@ -63,7 +63,7 @@ from typing import Callable # error: [invalid-type-form] "Special form `typing.Callable` expected exactly two arguments (parameter types and return type)" def _(c: Callable[[int, str]]): - reveal_type(c) # revealed: (int, str, /) -> Unknown + reveal_type(c) # revealed: (...) -> Unknown ``` Or, an ellipsis: @@ -74,6 +74,18 @@ def _(c: Callable[...]): reveal_type(c) # revealed: (...) -> Unknown ``` +Or something else that's invalid in a type expression generally: + +```py +# fmt: off + +def _(c: Callable[ # error: [invalid-type-form] "Special form `typing.Callable` expected exactly two arguments (parameter types and return type)" + {1, 2} # error: [invalid-type-form] "The first argument to `Callable` must be either a list of types, ParamSpec, Concatenate, or `...`" + ] + ): + reveal_type(c) # revealed: (...) -> Unknown +``` + ### More than two arguments We can't reliably infer the callable type if there are more then 2 arguments because we don't know @@ -87,6 +99,48 @@ def _(c: Callable[[int], str, str]): reveal_type(c) # revealed: (...) -> Unknown ``` +### List as the second argument + +```py +from typing import Callable + +# fmt: off + +def _(c: Callable[ + int, # error: [invalid-type-form] "The first argument to `Callable` must be either a list of types, ParamSpec, Concatenate, or `...`" + [str] # error: [invalid-type-form] "List literals are not allowed in this context in a type expression" + ] + ): + reveal_type(c) # revealed: (...) -> Unknown +``` + +### List as both arguments + +```py +from typing import Callable + +# error: [invalid-type-form] "List literals are not allowed in this context in a type expression" +def _(c: Callable[[int], [str]]): + reveal_type(c) # revealed: (int, /) -> Unknown +``` + +### Three list arguments + +```py +from typing import Callable + +# fmt: off + + +def _(c: Callable[ # error: [invalid-type-form] "Special form `typing.Callable` expected exactly two arguments (parameter types and return type)" + [int], + [str], # error: [invalid-type-form] "List literals are not allowed in this context in a type expression" + [bytes] # error: [invalid-type-form] "List literals are not allowed in this context in a type expression" + ] + ): + reveal_type(c) # revealed: (...) -> Unknown +``` + ## Simple A simple `Callable` with multiple parameters and a return type: diff --git a/crates/red_knot_python_semantic/resources/mdtest/annotations/invalid.md b/crates/red_knot_python_semantic/resources/mdtest/annotations/invalid.md index 7243dc4dc7..5daa0b6249 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/annotations/invalid.md +++ b/crates/red_knot_python_semantic/resources/mdtest/annotations/invalid.md @@ -96,6 +96,7 @@ def _( d: [k for k in [1, 2]], # error: [invalid-type-form] "List comprehensions are not allowed in type expressions" e: {k for k in [1, 2]}, # error: [invalid-type-form] "Set comprehensions are not allowed in type expressions" f: (k for k in [1, 2]), # error: [invalid-type-form] "Generator expressions are not allowed in type expressions" + g: [int, str], # error: [invalid-type-form] "List literals are not allowed in this context in a type expression" ): reveal_type(a) # revealed: Unknown reveal_type(b) # revealed: Unknown @@ -103,4 +104,5 @@ def _( reveal_type(d) # revealed: Unknown reveal_type(e) # revealed: Unknown reveal_type(f) # revealed: Unknown + reveal_type(g) # revealed: Unknown ``` diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 51ff34a81f..ba07aefb2c 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -6152,6 +6152,18 @@ impl<'db> TypeInferenceBuilder<'db> { ), ), + // TODO: add a subdiagnostic linking to type-expression grammar + // and stating that it is only valid as first argument to `typing.Callable[]` + ast::Expr::List(list) => { + self.infer_list_expression(list); + self.report_invalid_type_expression( + expression, + format_args!( + "List literals are not allowed in this context in a type expression" + ), + ) + } + ast::Expr::BoolOp(bool_op) => { self.infer_boolean_expression(bool_op); self.report_invalid_type_expression( @@ -6305,11 +6317,6 @@ impl<'db> TypeInferenceBuilder<'db> { todo_type!("ellipsis literal in type expression") } - ast::Expr::List(list) => { - self.infer_list_expression(list); - Type::unknown() - } - ast::Expr::Tuple(tuple) => { self.infer_tuple_expression(tuple); Type::unknown() @@ -6595,53 +6602,51 @@ impl<'db> TypeInferenceBuilder<'db> { todo_type!("Generic PEP-695 type alias") } KnownInstanceType::Callable => { - let ast::Expr::Tuple(ast::ExprTuple { - elts: arguments, .. - }) = arguments_slice - else { + let mut arguments = match arguments_slice { + ast::Expr::Tuple(tuple) => Either::Left(tuple.iter()), + _ => { + self.infer_callable_parameter_types(arguments_slice); + Either::Right(std::iter::empty::<&ast::Expr>()) + } + }; + + let first_argument = arguments.next(); + + let parameters = + first_argument.and_then(|arg| self.infer_callable_parameter_types(arg)); + + let return_type = arguments.next().map(|arg| self.infer_type_expression(arg)); + + let correct_argument_number = if let Some(third_argument) = arguments.next() { + self.infer_type_expression(third_argument); + for argument in arguments { + self.infer_type_expression(argument); + } + false + } else { + return_type.is_some() + }; + + if !correct_argument_number { report_invalid_arguments_to_callable(&self.context, subscript); + } - // If it's not a tuple, defer it to inferring the parameter types which could - // return an `Err` if the expression is invalid in that position. In which - // case, we'll fallback to using an unknown list of parameters. - let parameters = self - .infer_callable_parameter_types(arguments_slice) - .unwrap_or_else(|()| Parameters::unknown()); - - let callable_type = - Type::Callable(CallableType::General(GeneralCallableType::new( - db, - Signature::new(parameters, Some(Type::unknown())), - ))); - - // `Parameters` is not a `Type` variant, so we're storing the outer callable - // type on the arguments slice instead. - self.store_expression_type(arguments_slice, callable_type); - - return callable_type; + let callable_type = if let (Some(parameters), Some(return_type), true) = + (parameters, return_type, correct_argument_number) + { + GeneralCallableType::new(db, Signature::new(parameters, Some(return_type))) + } else { + GeneralCallableType::unknown(db) }; - let [first_argument, second_argument] = arguments.as_slice() else { - report_invalid_arguments_to_callable(&self.context, subscript); - self.infer_type_expression(arguments_slice); - return Type::Callable(CallableType::General(GeneralCallableType::unknown(db))); - }; - - let Ok(parameters) = self.infer_callable_parameter_types(first_argument) else { - self.infer_type_expression(arguments_slice); - return Type::Callable(CallableType::General(GeneralCallableType::unknown(db))); - }; - - let return_type = self.infer_type_expression(second_argument); - - let callable_type = Type::Callable(CallableType::General( - GeneralCallableType::new(db, Signature::new(parameters, Some(return_type))), - )); + let callable_type = Type::Callable(CallableType::General(callable_type)); // `Signature` / `Parameters` are not a `Type` variant, so we're storing // the outer callable type on the these expressions instead. self.store_expression_type(arguments_slice, callable_type); - self.store_expression_type(first_argument, callable_type); + if let Some(first_argument) = first_argument { + self.store_expression_type(first_argument, callable_type); + } callable_type } @@ -6932,13 +6937,13 @@ impl<'db> TypeInferenceBuilder<'db> { /// Infer the first argument to a `typing.Callable` type expression and returns the /// corresponding [`Parameters`]. /// - /// It returns an [`Err`] if the argument is invalid i.e., not a list of types, parameter + /// It returns `None` if the argument is invalid i.e., not a list of types, parameter /// specification, `typing.Concatenate`, or `...`. fn infer_callable_parameter_types( &mut self, parameters: &ast::Expr, - ) -> Result, ()> { - Ok(match parameters { + ) -> Option> { + Some(match parameters { ast::Expr::EllipsisLiteral(ast::ExprEllipsisLiteral { .. }) => { Parameters::gradual_form() } @@ -6982,7 +6987,7 @@ impl<'db> TypeInferenceBuilder<'db> { // This is a special case to avoid raising the error suggesting what the first // argument should be. This only happens when there's already a syntax error like // `Callable[]`. - return Err(()); + return None; } _ => { // TODO: Check whether `Expr::Name` is a ParamSpec @@ -6993,7 +6998,7 @@ impl<'db> TypeInferenceBuilder<'db> { "The first argument to `Callable` must be either a list of types, ParamSpec, Concatenate, or `...`", ), ); - return Err(()); + return None; } }) }