[red-knot] Improve type inference for except handlers (#14838)

This commit is contained in:
Alex Waygood 2024-12-09 22:49:58 +00:00 committed by GitHub
parent 64944f2cf5
commit ab26d9cf9a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 150 additions and 41 deletions

View file

@ -49,12 +49,44 @@ def foo(
try:
help()
except x as e:
# TODO: should be `AttributeError`
reveal_type(e) # revealed: @Todo(exception type)
reveal_type(e) # revealed: AttributeError
except y as f:
# TODO: should be `OSError | RuntimeError`
reveal_type(f) # revealed: @Todo(exception type)
reveal_type(f) # revealed: OSError | RuntimeError
except z as g:
# TODO: should be `BaseException`
reveal_type(g) # revealed: @Todo(exception type)
reveal_type(g) # revealed: @Todo(full tuple[...] support)
```
## Invalid exception handlers
```py
try:
pass
# error: [invalid-exception] "Cannot catch object of type `Literal[3]` in an exception handler (must be a `BaseException` subclass or a tuple of `BaseException` subclasses)"
except 3 as e:
reveal_type(e) # revealed: Unknown
try:
pass
# error: [invalid-exception] "Cannot catch object of type `Literal["foo"]` in an exception handler (must be a `BaseException` subclass or a tuple of `BaseException` subclasses)"
# error: [invalid-exception] "Cannot catch object of type `Literal[b"bar"]` in an exception handler (must be a `BaseException` subclass or a tuple of `BaseException` subclasses)"
except (ValueError, OSError, "foo", b"bar") as e:
reveal_type(e) # revealed: ValueError | OSError | Unknown
def foo(
x: type[str],
y: tuple[type[OSError], type[RuntimeError], int],
z: tuple[type[str], ...],
):
try:
help()
# error: [invalid-exception]
except x as e:
reveal_type(e) # revealed: Unknown
# error: [invalid-exception]
except y as f:
reveal_type(f) # revealed: OSError | RuntimeError | Unknown
except z as g:
# TODO: should emit a diagnostic here:
reveal_type(g) # revealed: @Todo(full tuple[...] support)
```

View file

@ -1,30 +1,59 @@
# Except star
# `except*`
## Except\* with BaseException
## `except*` with `BaseException`
```py
try:
help()
except* BaseException as e:
# TODO: should be `BaseExceptionGroup[BaseException]` --Alex
reveal_type(e) # revealed: BaseExceptionGroup
```
## Except\* with specific exception
## `except*` with specific exception
```py
try:
help()
except* OSError as e:
# TODO(Alex): more precise would be `ExceptionGroup[OSError]`
# TODO: more precise would be `ExceptionGroup[OSError]` --Alex
# (needs homogenous tuples + generics)
reveal_type(e) # revealed: BaseExceptionGroup
```
## Except\* with multiple exceptions
## `except*` with multiple exceptions
```py
try:
help()
except* (TypeError, AttributeError) as e:
# TODO(Alex): more precise would be `ExceptionGroup[TypeError | AttributeError]`.
# TODO: more precise would be `ExceptionGroup[TypeError | AttributeError]` --Alex
# (needs homogenous tuples + generics)
reveal_type(e) # revealed: BaseExceptionGroup
```
## `except*` with mix of `Exception`s and `BaseException`s
```py
try:
help()
except* (KeyboardInterrupt, AttributeError) as e:
# TODO: more precise would be `BaseExceptionGroup[KeyboardInterrupt | AttributeError]` --Alex
reveal_type(e) # revealed: BaseExceptionGroup
```
## Invalid `except*` handlers
```py
try:
help()
except* 3 as e: # error: [invalid-exception]
# TODO: Should be `BaseExceptionGroup[Unknown]` --Alex
reveal_type(e) # revealed: BaseExceptionGroup
try:
help()
except* (AttributeError, 42) as e: # error: [invalid-exception]
# TODO: Should be `BaseExceptionGroup[AttributeError | Unknown]` --Alex
reveal_type(e) # revealed: BaseExceptionGroup
```

View file

@ -1173,6 +1173,8 @@ impl<'db> Type<'db> {
| KnownClass::Set
| KnownClass::Dict
| KnownClass::Slice
| KnownClass::BaseException
| KnownClass::BaseExceptionGroup
| KnownClass::GenericAlias
| KnownClass::ModuleType
| KnownClass::FunctionType
@ -1845,6 +1847,8 @@ pub enum KnownClass {
Set,
Dict,
Slice,
BaseException,
BaseExceptionGroup,
// Types
GenericAlias,
ModuleType,
@ -1875,6 +1879,8 @@ impl<'db> KnownClass {
Self::List => "list",
Self::Type => "type",
Self::Slice => "slice",
Self::BaseException => "BaseException",
Self::BaseExceptionGroup => "BaseExceptionGroup",
Self::GenericAlias => "GenericAlias",
Self::ModuleType => "ModuleType",
Self::FunctionType => "FunctionType",
@ -1902,6 +1908,12 @@ impl<'db> KnownClass {
.unwrap_or(Type::Unknown)
}
pub fn to_subclass_of(self, db: &'db dyn Db) -> Option<Type<'db>> {
self.to_class_literal(db)
.into_class_literal()
.map(|ClassLiteralType { class }| Type::subclass_of(class))
}
/// Return the module in which we should look up the definition for this class
pub(crate) fn canonical_module(self, db: &'db dyn Db) -> CoreStdlibModule {
match self {
@ -1916,6 +1928,8 @@ impl<'db> KnownClass {
| Self::Tuple
| Self::Set
| Self::Dict
| Self::BaseException
| Self::BaseExceptionGroup
| Self::Slice => CoreStdlibModule::Builtins,
Self::VersionInfo => CoreStdlibModule::Sys,
Self::GenericAlias | Self::ModuleType | Self::FunctionType => CoreStdlibModule::Types,
@ -1959,6 +1973,8 @@ impl<'db> KnownClass {
| Self::ModuleType
| Self::FunctionType
| Self::SpecialForm
| Self::BaseException
| Self::BaseExceptionGroup
| Self::TypeVar => false,
}
}
@ -1980,6 +1996,8 @@ impl<'db> KnownClass {
"dict" => Self::Dict,
"list" => Self::List,
"slice" => Self::Slice,
"BaseException" => Self::BaseException,
"BaseExceptionGroup" => Self::BaseExceptionGroup,
"GenericAlias" => Self::GenericAlias,
"NoneType" => Self::NoneType,
"ModuleType" => Self::ModuleType,
@ -2016,6 +2034,8 @@ impl<'db> KnownClass {
| Self::GenericAlias
| Self::ModuleType
| Self::VersionInfo
| Self::BaseException
| Self::BaseExceptionGroup
| Self::FunctionType => module.name() == self.canonical_module(db).as_str(),
Self::NoneType => matches!(module.name().as_str(), "_typeshed" | "types"),
Self::SpecialForm | Self::TypeVar | Self::TypeAliasType | Self::NoDefaultType => {

View file

@ -289,6 +289,18 @@ impl<'db> TypeCheckDiagnosticsBuilder<'db> {
);
}
pub(super) fn add_invalid_exception(&mut self, db: &dyn Db, node: &ast::Expr, ty: Type) {
self.add(
node.into(),
"invalid-exception",
format_args!(
"Cannot catch object of type `{}` in an exception handler \
(must be a `BaseException` subclass or a tuple of `BaseException` subclasses)",
ty.display(db)
),
);
}
/// Adds a new diagnostic.
///
/// The diagnostic does not get added if the rule isn't enabled for this file.

View file

@ -1535,40 +1535,56 @@ impl<'db> TypeInferenceBuilder<'db> {
except_handler_definition: &ExceptHandlerDefinitionKind,
definition: Definition<'db>,
) {
let node_ty = except_handler_definition
.handled_exceptions()
.map(|ty| self.infer_expression(ty))
// If there is no handled exception, it's invalid syntax;
// a diagnostic will have already been emitted
.unwrap_or(Type::Unknown);
let node = except_handler_definition.handled_exceptions();
// If there is no handled exception, it's invalid syntax;
// a diagnostic will have already been emitted
let node_ty = node.map_or(Type::Unknown, |ty| self.infer_expression(ty));
// If it's an `except*` handler, this won't actually be the type of the bound symbol;
// it will actually be the type of the generic parameters to `BaseExceptionGroup` or `ExceptionGroup`.
let symbol_ty = if let Type::Tuple(tuple) = node_ty {
let type_base_exception = KnownClass::BaseException
.to_subclass_of(self.db)
.unwrap_or(Type::Unknown);
let mut builder = UnionBuilder::new(self.db);
for element in tuple.elements(self.db).iter().copied() {
builder = builder.add(if element.is_assignable_to(self.db, type_base_exception) {
element.to_instance(self.db)
} else {
if let Some(node) = node {
self.diagnostics
.add_invalid_exception(self.db, node, element);
}
Type::Unknown
});
}
builder.build()
} else if node_ty.is_subtype_of(self.db, KnownClass::Tuple.to_instance(self.db)) {
todo_type!("Homogeneous tuple in exception handler")
} else {
let type_base_exception = KnownClass::BaseException
.to_subclass_of(self.db)
.unwrap_or(Type::Unknown);
if node_ty.is_assignable_to(self.db, type_base_exception) {
node_ty.to_instance(self.db)
} else {
if let Some(node) = node {
self.diagnostics
.add_invalid_exception(self.db, node, node_ty);
}
Type::Unknown
}
};
let symbol_ty = if except_handler_definition.is_star() {
// TODO should be generic --Alex
// TODO: we should infer `ExceptionGroup` if `node_ty` is a subtype of `tuple[type[Exception], ...]`
// (needs support for homogeneous tuples).
//
// TODO should infer `ExceptionGroup` if all caught exceptions
// are subclasses of `Exception` --Alex
builtins_symbol(self.db, "BaseExceptionGroup")
.ignore_possibly_unbound()
.unwrap_or(Type::Unknown)
.to_instance(self.db)
// TODO: should be generic with `symbol_ty` as the generic parameter
KnownClass::BaseExceptionGroup.to_instance(self.db)
} else {
// TODO: anything that's a consistent subtype of
// `type[BaseException] | tuple[type[BaseException], ...]` should be valid;
// anything else is invalid and should lead to a diagnostic being reported --Alex
match node_ty {
Type::Any | Type::Unknown => node_ty,
Type::ClassLiteral(ClassLiteralType { class }) => Type::instance(class),
Type::Tuple(tuple) => UnionType::from_elements(
self.db,
tuple.elements(self.db).iter().map(|ty| {
ty.into_class_literal().map_or(
todo_type!("exception type"),
|ClassLiteralType { class }| Type::instance(class),
)
}),
),
_ => todo_type!("exception type"),
}
symbol_ty
};
self.add_binding(