[red-knot] Fix type inference for except* definitions (#13320)

This commit is contained in:
Alex Waygood 2024-09-11 15:05:40 -04:00 committed by GitHub
parent b72d49be16
commit 4dc2c257ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 179 additions and 51 deletions

View file

@ -27,7 +27,9 @@ use crate::semantic_index::SemanticIndex;
use crate::Db;
use super::constraint::{Constraint, PatternConstraint};
use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};
use super::definition::{
ExceptHandlerDefinitionNodeRef, MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef,
};
pub(super) struct SemanticIndexBuilder<'db> {
// Builder state
@ -696,6 +698,51 @@ where
self.flow_merge(after_subject);
}
}
ast::Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
is_star,
range: _,
}) => {
self.visit_body(body);
for except_handler in handlers {
let ast::ExceptHandler::ExceptHandler(except_handler) = except_handler;
let ast::ExceptHandlerExceptHandler {
name: symbol_name,
type_: handled_exceptions,
body: handler_body,
range: _,
} = except_handler;
if let Some(handled_exceptions) = handled_exceptions {
self.visit_expr(handled_exceptions);
}
// If `handled_exceptions` above was `None`, it's something like `except as e:`,
// which is invalid syntax. However, it's still pretty obvious here that the user
// *wanted* `e` to be bound, so we should still create a definition here nonetheless.
if let Some(symbol_name) = symbol_name {
let symbol = self
.add_or_update_symbol(symbol_name.id.clone(), SymbolFlags::IS_DEFINED);
self.add_definition(
symbol,
DefinitionNodeRef::ExceptHandler(ExceptHandlerDefinitionNodeRef {
handler: except_handler,
is_star: *is_star,
}),
);
}
self.visit_body(handler_body);
}
self.visit_body(orelse);
self.visit_body(finalbody);
}
_ => {
walk_stmt(self, stmt);
}
@ -958,30 +1005,6 @@ where
self.current_match_case.as_mut().unwrap().index += 1;
}
fn visit_except_handler(&mut self, except_handler: &'ast ast::ExceptHandler) {
let ast::ExceptHandler::ExceptHandler(except_handler) = except_handler;
let ast::ExceptHandlerExceptHandler {
name: symbol_name,
type_: handled_exceptions,
body,
range: _,
} = except_handler;
if let Some(handled_exceptions) = handled_exceptions {
self.visit_expr(handled_exceptions);
}
// If `handled_exceptions` above was `None`, it's something like `except as e:`,
// which is invalid syntax. However, it's still pretty obvious here that the user
// *wanted* `e` to be bound, so we should still create a definition here nonetheless.
if let Some(symbol_name) = symbol_name {
let symbol = self.add_or_update_symbol(symbol_name.id.clone(), SymbolFlags::IS_DEFINED);
self.add_definition(symbol, except_handler);
}
self.visit_body(body);
}
}
#[derive(Copy, Clone, Debug)]

View file

@ -50,7 +50,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
Parameter(ast::AnyParameterRef<'a>),
WithItem(WithItemDefinitionNodeRef<'a>),
MatchPattern(MatchPatternDefinitionNodeRef<'a>),
ExceptHandler(&'a ast::ExceptHandlerExceptHandler),
ExceptHandler(ExceptHandlerDefinitionNodeRef<'a>),
}
impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
@ -131,12 +131,6 @@ impl<'a> From<MatchPatternDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
}
}
impl<'a> From<&'a ast::ExceptHandlerExceptHandler> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::ExceptHandlerExceptHandler) -> Self {
Self::ExceptHandler(node)
}
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct ImportFromDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::StmtImportFrom,
@ -162,6 +156,12 @@ pub(crate) struct ForStmtDefinitionNodeRef<'a> {
pub(crate) is_async: bool,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct ExceptHandlerDefinitionNodeRef<'a> {
pub(crate) handler: &'a ast::ExceptHandlerExceptHandler,
pub(crate) is_star: bool,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) iterable: &'a ast::Expr,
@ -258,9 +258,13 @@ impl DefinitionNodeRef<'_> {
identifier: AstNodeRef::new(parsed, identifier),
index,
}),
DefinitionNodeRef::ExceptHandler(handler) => {
DefinitionKind::ExceptHandler(AstNodeRef::new(parsed, handler))
}
DefinitionNodeRef::ExceptHandler(ExceptHandlerDefinitionNodeRef {
handler,
is_star,
}) => DefinitionKind::ExceptHandler(ExceptHandlerDefinitionKind {
handler: AstNodeRef::new(parsed.clone(), handler),
is_star,
}),
}
}
@ -293,7 +297,7 @@ impl DefinitionNodeRef<'_> {
Self::MatchPattern(MatchPatternDefinitionNodeRef { identifier, .. }) => {
identifier.into()
}
Self::ExceptHandler(handler) => handler.into(),
Self::ExceptHandler(ExceptHandlerDefinitionNodeRef { handler, .. }) => handler.into(),
}
}
}
@ -314,7 +318,7 @@ pub enum DefinitionKind {
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
WithItem(WithItemDefinitionKind),
MatchPattern(MatchPatternDefinitionKind),
ExceptHandler(AstNodeRef<ast::ExceptHandlerExceptHandler>),
ExceptHandler(ExceptHandlerDefinitionKind),
}
#[derive(Clone, Debug)]
@ -430,6 +434,22 @@ impl ForStmtDefinitionKind {
}
}
#[derive(Clone, Debug)]
pub struct ExceptHandlerDefinitionKind {
handler: AstNodeRef<ast::ExceptHandlerExceptHandler>,
is_star: bool,
}
impl ExceptHandlerDefinitionKind {
pub(crate) fn handled_exceptions(&self) -> Option<&ast::Expr> {
self.handler.node().type_.as_deref()
}
pub(crate) fn is_star(&self) -> bool {
self.is_star
}
}
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub(crate) struct DefinitionNodeKey(NodeKey);

View file

@ -40,7 +40,9 @@ use ruff_text_size::Ranged;
use crate::module_name::ModuleName;
use crate::module_resolver::{file_to_module, resolve_module};
use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId};
use crate::semantic_index::definition::{Definition, DefinitionKind, DefinitionNodeKey};
use crate::semantic_index::definition::{
Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind,
};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::semantic_index;
use crate::semantic_index::symbol::{NodeWithScopeKind, NodeWithScopeRef, ScopeId};
@ -426,8 +428,8 @@ impl<'db> TypeInferenceBuilder<'db> {
definition,
);
}
DefinitionKind::ExceptHandler(handler) => {
self.infer_except_handler_definition(handler, definition);
DefinitionKind::ExceptHandler(except_handler_definition) => {
self.infer_except_handler_definition(except_handler_definition, definition);
}
}
}
@ -821,22 +823,29 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_except_handler_definition(
&mut self,
handler: &'db ast::ExceptHandlerExceptHandler,
except_handler_definition: &ExceptHandlerDefinitionKind,
definition: Definition<'db>,
) {
let node_ty = handler
.type_
.as_deref()
let node_ty = except_handler_definition
.handled_exceptions()
.map(|ty| self.infer_expression(ty))
.unwrap_or(Type::Unknown);
// TODO: anything that's a consistent subtype of
// `type[BaseException] | tuple[type[BaseException], ...]` should be valid;
// anything else should be invalid --Alex
let symbol_ty = match node_ty {
Type::Any | Type::Unknown => node_ty,
Type::Class(class_ty) => Type::Instance(class_ty),
_ => Type::Unknown,
let symbol_ty = if except_handler_definition.is_star() {
// TODO should be generic --Alex
//
// TODO should infer `ExceptionGroup` if all caught exceptions
// are subclasses of `Exception` --Alex
builtins_symbol_ty(self.db, "BaseExceptionGroup").to_instance(self.db)
} else {
// TODO: anything that's a consistent subtype of
// `type[BaseException] | tuple[type[BaseException], ...]` should be valid;
// anything else should be invalid --Alex
match node_ty {
Type::Any | Type::Unknown => node_ty,
Type::Class(class_ty) => Type::Instance(class_ty),
_ => Type::Unknown,
}
};
self.types.definitions.insert(definition, symbol_ty);
@ -4563,6 +4572,82 @@ mod tests {
Ok(())
}
#[test]
fn except_star_handler_baseexception() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
try:
x
except* BaseException as e:
pass
",
)?;
assert_file_diagnostics(&db, "src/a.py", &[]);
// TODO: once we support `sys.version_info` branches,
// we can set `--target-version=py311` in this test
// and the inferred type will just be `BaseExceptionGroup` --Alex
assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup");
Ok(())
}
#[test]
fn except_star_handler() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
try:
x
except* OSError as e:
pass
",
)?;
assert_file_diagnostics(&db, "src/a.py", &[]);
// TODO: once we support `sys.version_info` branches,
// we can set `--target-version=py311` in this test
// and the inferred type will just be `BaseExceptionGroup` --Alex
//
// TODO more precise would be `ExceptionGroup[OSError]` --Alex
assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup");
Ok(())
}
#[test]
fn except_star_handler_multiple_types() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
try:
x
except* (TypeError, AttributeError) as e:
pass
",
)?;
assert_file_diagnostics(&db, "src/a.py", &[]);
// TODO: once we support `sys.version_info` branches,
// we can set `--target-version=py311` in this test
// and the inferred type will just be `BaseExceptionGroup` --Alex
//
// TODO more precise would be `ExceptionGroup[TypeError | AttributeError]` --Alex
assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup");
Ok(())
}
#[test]
fn basic_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();