[red-knot] Handle context managers in (sync) with statements (#13998)

This commit is contained in:
Micha Reiser 2024-10-31 09:18:18 +01:00 committed by GitHub
parent 2d917d72f6
commit 76e4277696
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 315 additions and 33 deletions

View file

@ -0,0 +1,20 @@
# Async with statements
## Basic `async with` statement
The type of the target variable in a `with` statement should be the return type from the context manager's `__aenter__` method.
However, `async with` statements aren't supported yet. This test asserts that it doesn't emit any context manager-related errors.
```py
class Target: ...
class Manager:
async def __aenter__(self) -> Target:
return Target()
async def __aexit__(self, exc_type, exc_value, traceback): ...
async def test():
async with Manager() as f:
reveal_type(f) # revealed: @Todo
```

View file

@ -0,0 +1,140 @@
# With statements
## Basic `with` statement
The type of the target variable in a `with` statement is the return type from the context manager's `__enter__` method.
```py
class Target: ...
class Manager:
def __enter__(self) -> Target:
return Target()
def __exit__(self, exc_type, exc_value, traceback): ...
with Manager() as f:
reveal_type(f) # revealed: Target
```
## Union context manager
```py
def coinflip() -> bool:
return True
class Manager1:
def __enter__(self) -> str:
return "foo"
def __exit__(self, exc_type, exc_value, traceback): ...
class Manager2:
def __enter__(self) -> int:
return 42
def __exit__(self, exc_type, exc_value, traceback): ...
context_expr = Manager1() if coinflip() else Manager2()
with context_expr as f:
reveal_type(f) # revealed: str | int
```
## Context manager without an `__enter__` or `__exit__` method
```py
class Manager: ...
# error: [invalid-context-manager] "Object of type Manager cannot be used with `with` because it doesn't implement `__enter__` and `__exit__`"
with Manager():
...
```
## Context manager without an `__enter__` method
```py
class Manager:
def __exit__(self, exc_tpe, exc_value, traceback): ...
# error: [invalid-context-manager] "Object of type Manager cannot be used with `with` because it doesn't implement `__enter__`"
with Manager():
...
```
## Context manager without an `__exit__` method
```py
class Manager:
def __enter__(self): ...
# error: [invalid-context-manager] "Object of type Manager cannot be used with `with` because it doesn't implement `__exit__`"
with Manager():
...
```
## Context manager with non-callable `__enter__` attribute
```py
class Manager:
__enter__ = 42
def __exit__(self, exc_tpe, exc_value, traceback): ...
# error: [invalid-context-manager] "Object of type Manager cannot be used with `with` because the method `__enter__` of type Literal[42] is not callable"
with Manager():
...
```
## Context manager with non-callable `__exit__` attribute
```py
class Manager:
def __enter__(self) -> Self: ...
__exit__ = 32
# error: [invalid-context-manager] "Object of type Manager cannot be used with `with` because the method `__exit__` of type Literal[32] is not callable"
with Manager():
...
```
## Context expression with non-callable union variants
```py
def coinflip() -> bool:
return True
class Manager1:
def __enter__(self) -> str:
return "foo"
def __exit__(self, exc_type, exc_value, traceback): ...
class NotAContextManager: ...
context_expr = Manager1() if coinflip() else NotAContextManager()
# error: [invalid-context-manager] "Object of type Manager1 | NotAContextManager cannot be used with `with` because the method `__enter__` of type Literal[__enter__] | Unbound is not callable"
# error: [invalid-context-manager] "Object of type Manager1 | NotAContextManager cannot be used with `with` because the method `__exit__` of type Literal[__exit__] | Unbound is not callable"
with context_expr as f:
reveal_type(f) # revealed: str | Unknown
```
## Context expression with "sometimes" callable `__enter__` method
```py
def coinflip() -> bool:
return True
class Manager:
if coinflip():
def __enter__(self) -> str:
return "abcd"
def __exit__(self, *args): ...
with Manager() as f:
# TODO: This should emit an error that `__enter__` is possibly unbound.
reveal_type(f) # revealed: str
```

View file

@ -734,12 +734,20 @@ where
self.flow_merge(break_state);
}
}
ast::Stmt::With(ast::StmtWith { items, body, .. }) => {
ast::Stmt::With(ast::StmtWith {
items,
body,
is_async,
..
}) => {
for item in items {
self.visit_expr(&item.context_expr);
if let Some(optional_vars) = item.optional_vars.as_deref() {
self.add_standalone_expression(&item.context_expr);
self.push_assignment(item.into());
self.push_assignment(CurrentAssignment::WithItem {
item,
is_async: *is_async,
});
self.visit_expr(optional_vars);
self.pop_assignment();
}
@ -1011,12 +1019,13 @@ where
},
);
}
Some(CurrentAssignment::WithItem(with_item)) => {
Some(CurrentAssignment::WithItem { item, is_async }) => {
self.add_definition(
symbol,
WithItemDefinitionNodeRef {
node: with_item,
node: item,
target: name_node,
is_async,
},
);
}
@ -1232,7 +1241,10 @@ enum CurrentAssignment<'a> {
node: &'a ast::Comprehension,
first: bool,
},
WithItem(&'a ast::WithItem),
WithItem {
item: &'a ast::WithItem,
is_async: bool,
},
}
impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> {
@ -1259,12 +1271,6 @@ impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
}
}
impl<'a> From<&'a ast::WithItem> for CurrentAssignment<'a> {
fn from(value: &'a ast::WithItem) -> Self {
Self::WithItem(value)
}
}
struct CurrentMatchCase<'a> {
/// The pattern that's part of the current match case.
pattern: &'a ast::Pattern,

View file

@ -176,6 +176,7 @@ pub(crate) struct AssignmentDefinitionNodeRef<'a> {
pub(crate) struct WithItemDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::WithItem,
pub(crate) target: &'a ast::ExprName,
pub(crate) is_async: bool,
}
#[derive(Copy, Clone, Debug)]
@ -277,12 +278,15 @@ impl DefinitionNodeRef<'_> {
DefinitionKind::ParameterWithDefault(AstNodeRef::new(parsed, parameter))
}
},
DefinitionNodeRef::WithItem(WithItemDefinitionNodeRef { node, target }) => {
DefinitionKind::WithItem(WithItemDefinitionKind {
node: AstNodeRef::new(parsed.clone(), node),
target: AstNodeRef::new(parsed, target),
})
}
DefinitionNodeRef::WithItem(WithItemDefinitionNodeRef {
node,
target,
is_async,
}) => DefinitionKind::WithItem(WithItemDefinitionKind {
node: AstNodeRef::new(parsed.clone(), node),
target: AstNodeRef::new(parsed, target),
is_async,
}),
DefinitionNodeRef::MatchPattern(MatchPatternDefinitionNodeRef {
pattern,
identifier,
@ -329,7 +333,11 @@ impl DefinitionNodeRef<'_> {
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
},
Self::WithItem(WithItemDefinitionNodeRef { node: _, target }) => target.into(),
Self::WithItem(WithItemDefinitionNodeRef {
node: _,
target,
is_async: _,
}) => target.into(),
Self::MatchPattern(MatchPatternDefinitionNodeRef { identifier, .. }) => {
identifier.into()
}
@ -534,6 +542,7 @@ pub enum AssignmentKind {
pub struct WithItemDefinitionKind {
node: AstNodeRef<ast::WithItem>,
target: AstNodeRef<ast::ExprName>,
is_async: bool,
}
impl WithItemDefinitionKind {
@ -544,6 +553,10 @@ impl WithItemDefinitionKind {
pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
}
pub(crate) const fn is_async(&self) -> bool {
self.is_async
}
}
#[derive(Clone, Debug)]

View file

@ -1009,9 +1009,7 @@ impl<'db> Type<'db> {
}
}
/// Return the type resulting from calling an object of this type.
///
/// Returns `None` if `self` is not a callable type.
/// Return the outcome of calling an object of this type.
#[must_use]
fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> {
match self {

View file

@ -359,6 +359,7 @@ impl<'db> TypeInferenceBuilder<'db> {
/// Get the already-inferred type of an expression node.
///
/// PANIC if no type has been inferred for this node.
#[track_caller]
fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> {
self.types
.expression_ty(expr.scoped_ast_id(self.db, self.scope()))
@ -479,7 +480,12 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_parameter_with_default_definition(parameter_with_default, definition);
}
DefinitionKind::WithItem(with_item) => {
self.infer_with_item_definition(with_item.target(), with_item.node(), definition);
self.infer_with_item_definition(
with_item.target(),
with_item.node(),
with_item.is_async(),
definition,
);
}
DefinitionKind::MatchPattern(match_pattern) => {
self.infer_match_pattern_definition(
@ -973,18 +979,21 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_with_statement(&mut self, with_statement: &ast::StmtWith) {
let ast::StmtWith {
range: _,
is_async: _,
is_async,
items,
body,
} = with_statement;
for item in items {
let target = item.optional_vars.as_deref();
if let Some(ast::Expr::Name(name)) = target {
self.infer_definition(name);
} else {
// TODO infer definitions in unpacking assignment
self.infer_expression(&item.context_expr);
// Call into the context expression inference to validate that it evaluates
// to a valid context manager.
let context_expression_ty = self.infer_expression(&item.context_expr);
self.infer_context_expression(&item.context_expr, context_expression_ty, *is_async);
self.infer_optional_expression(target);
}
}
@ -996,20 +1005,116 @@ impl<'db> TypeInferenceBuilder<'db> {
&mut self,
target: &ast::ExprName,
with_item: &ast::WithItem,
is_async: bool,
definition: Definition<'db>,
) {
let expression = self.index.expression(&with_item.context_expr);
let result = infer_expression_types(self.db, expression);
self.extend(result);
let context_expr = self.index.expression(&with_item.context_expr);
self.extend(infer_expression_types(self.db, context_expr));
// TODO(dhruvmanila): The correct type inference here is the return type of the __enter__
// method of the context manager.
let context_expr_ty = self.expression_ty(&with_item.context_expr);
let target_ty = self.infer_context_expression(
&with_item.context_expr,
self.expression_ty(&with_item.context_expr),
is_async,
);
self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope()), context_expr_ty);
self.add_binding(target.into(), definition, context_expr_ty);
.insert(target.scoped_ast_id(self.db, self.scope()), target_ty);
self.add_binding(target.into(), definition, target_ty);
}
/// Infers the type of a context expression (`with expr`) and returns the target's type
///
/// Returns [`Type::Unknown`] if the context expression doesn't implement the context manager protocol.
///
/// ## Terminology
/// See [PEP343](https://peps.python.org/pep-0343/#standard-terminology).
fn infer_context_expression(
&mut self,
context_expression: &ast::Expr,
context_expression_ty: Type<'db>,
is_async: bool,
) -> Type<'db> {
// TODO: Handle async with statements (they use `aenter` and `aexit`)
if is_async {
return Type::Todo;
}
let context_manager_ty = context_expression_ty.to_meta_type(self.db);
let enter_ty = context_manager_ty.member(self.db, "__enter__");
let exit_ty = context_manager_ty.member(self.db, "__exit__");
// TODO: Make use of Protocols when we support it (the manager be assignable to `contextlib.AbstractContextManager`).
if enter_ty.is_unbound() && exit_ty.is_unbound() {
self.diagnostics.add(
context_expression.into(),
"invalid-context-manager",
format_args!(
"Object of type {} cannot be used with `with` because it doesn't implement `__enter__` and `__exit__`",
context_expression_ty.display(self.db)
),
);
Type::Unknown
} else if enter_ty.is_unbound() {
self.diagnostics.add(
context_expression.into(),
"invalid-context-manager",
format_args!(
"Object of type {} cannot be used with `with` because it doesn't implement `__enter__`",
context_expression_ty.display(self.db)
),
);
Type::Unknown
} else {
let target_ty = enter_ty
.call(self.db, &[context_expression_ty])
.return_ty_result(self.db, context_expression.into(), &mut self.diagnostics)
.unwrap_or_else(|err| {
self.diagnostics.add(
context_expression.into(),
"invalid-context-manager",
format_args!("
Object of type {context_expression} cannot be used with `with` because the method `__enter__` of type {enter_ty} is not callable",
context_expression = context_expression_ty.display(self.db),
enter_ty = enter_ty.display(self.db)
),
);
err.return_ty()
});
if exit_ty.is_unbound() {
self.diagnostics.add(
context_expression.into(),
"invalid-context-manager",
format_args!(
"Object of type {} cannot be used with `with` because it doesn't implement `__exit__`",
context_expression_ty.display(self.db)
),
);
}
// TODO: Use the `exit_ty` to determine if any raised exception is suppressed.
else if exit_ty
.call(
self.db,
&[context_manager_ty, Type::None, Type::None, Type::None],
)
.return_ty_result(self.db, context_expression.into(), &mut self.diagnostics)
.is_err()
{
self.diagnostics.add(
context_expression.into(),
"invalid-context-manager",
format_args!(
"Object of type {context_expression} cannot be used with `with` because the method `__exit__` of type {exit_ty} is not callable",
context_expression = context_expression_ty.display(self.db),
exit_ty = exit_ty.display(self.db),
),
);
}
target_ty
}
}
fn infer_except_handler_definition(