mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 21:35:58 +00:00
[red-knot] Handle context managers in (sync) with statements (#13998)
This commit is contained in:
parent
2d917d72f6
commit
76e4277696
6 changed files with 315 additions and 33 deletions
|
@ -0,0 +1,20 @@
|
|||
# Async with statements
|
||||
|
||||
## Basic `async with` statement
|
||||
|
||||
The type of the target variable in a `with` statement should be the return type from the context manager's `__aenter__` method.
|
||||
However, `async with` statements aren't supported yet. This test asserts that it doesn't emit any context manager-related errors.
|
||||
|
||||
```py
|
||||
class Target: ...
|
||||
|
||||
class Manager:
|
||||
async def __aenter__(self) -> Target:
|
||||
return Target()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback): ...
|
||||
|
||||
async def test():
|
||||
async with Manager() as f:
|
||||
reveal_type(f) # revealed: @Todo
|
||||
```
|
140
crates/red_knot_python_semantic/resources/mdtest/with/with.md
Normal file
140
crates/red_knot_python_semantic/resources/mdtest/with/with.md
Normal file
|
@ -0,0 +1,140 @@
|
|||
# With statements
|
||||
|
||||
## Basic `with` statement
|
||||
|
||||
The type of the target variable in a `with` statement is the return type from the context manager's `__enter__` method.
|
||||
|
||||
```py
|
||||
class Target: ...
|
||||
|
||||
class Manager:
|
||||
def __enter__(self) -> Target:
|
||||
return Target()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): ...
|
||||
|
||||
with Manager() as f:
|
||||
reveal_type(f) # revealed: Target
|
||||
```
|
||||
|
||||
## Union context manager
|
||||
|
||||
```py
|
||||
def coinflip() -> bool:
|
||||
return True
|
||||
|
||||
class Manager1:
|
||||
def __enter__(self) -> str:
|
||||
return "foo"
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): ...
|
||||
|
||||
class Manager2:
|
||||
def __enter__(self) -> int:
|
||||
return 42
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): ...
|
||||
|
||||
context_expr = Manager1() if coinflip() else Manager2()
|
||||
|
||||
with context_expr as f:
|
||||
reveal_type(f) # revealed: str | int
|
||||
```
|
||||
|
||||
## Context manager without an `__enter__` or `__exit__` method
|
||||
|
||||
```py
|
||||
class Manager: ...
|
||||
|
||||
# error: [invalid-context-manager] "Object of type Manager cannot be used with `with` because it doesn't implement `__enter__` and `__exit__`"
|
||||
with Manager():
|
||||
...
|
||||
```
|
||||
|
||||
## Context manager without an `__enter__` method
|
||||
|
||||
```py
|
||||
class Manager:
|
||||
def __exit__(self, exc_tpe, exc_value, traceback): ...
|
||||
|
||||
# error: [invalid-context-manager] "Object of type Manager cannot be used with `with` because it doesn't implement `__enter__`"
|
||||
with Manager():
|
||||
...
|
||||
```
|
||||
|
||||
## Context manager without an `__exit__` method
|
||||
|
||||
```py
|
||||
class Manager:
|
||||
def __enter__(self): ...
|
||||
|
||||
# error: [invalid-context-manager] "Object of type Manager cannot be used with `with` because it doesn't implement `__exit__`"
|
||||
with Manager():
|
||||
...
|
||||
```
|
||||
|
||||
## Context manager with non-callable `__enter__` attribute
|
||||
|
||||
```py
|
||||
class Manager:
|
||||
__enter__ = 42
|
||||
|
||||
def __exit__(self, exc_tpe, exc_value, traceback): ...
|
||||
|
||||
# error: [invalid-context-manager] "Object of type Manager cannot be used with `with` because the method `__enter__` of type Literal[42] is not callable"
|
||||
with Manager():
|
||||
...
|
||||
```
|
||||
|
||||
## Context manager with non-callable `__exit__` attribute
|
||||
|
||||
```py
|
||||
class Manager:
|
||||
def __enter__(self) -> Self: ...
|
||||
|
||||
__exit__ = 32
|
||||
|
||||
# error: [invalid-context-manager] "Object of type Manager cannot be used with `with` because the method `__exit__` of type Literal[32] is not callable"
|
||||
with Manager():
|
||||
...
|
||||
```
|
||||
|
||||
## Context expression with non-callable union variants
|
||||
|
||||
```py
|
||||
def coinflip() -> bool:
|
||||
return True
|
||||
|
||||
class Manager1:
|
||||
def __enter__(self) -> str:
|
||||
return "foo"
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): ...
|
||||
|
||||
class NotAContextManager: ...
|
||||
|
||||
context_expr = Manager1() if coinflip() else NotAContextManager()
|
||||
|
||||
# error: [invalid-context-manager] "Object of type Manager1 | NotAContextManager cannot be used with `with` because the method `__enter__` of type Literal[__enter__] | Unbound is not callable"
|
||||
# error: [invalid-context-manager] "Object of type Manager1 | NotAContextManager cannot be used with `with` because the method `__exit__` of type Literal[__exit__] | Unbound is not callable"
|
||||
with context_expr as f:
|
||||
reveal_type(f) # revealed: str | Unknown
|
||||
```
|
||||
|
||||
## Context expression with "sometimes" callable `__enter__` method
|
||||
|
||||
```py
|
||||
def coinflip() -> bool:
|
||||
return True
|
||||
|
||||
class Manager:
|
||||
if coinflip():
|
||||
def __enter__(self) -> str:
|
||||
return "abcd"
|
||||
|
||||
def __exit__(self, *args): ...
|
||||
|
||||
with Manager() as f:
|
||||
# TODO: This should emit an error that `__enter__` is possibly unbound.
|
||||
reveal_type(f) # revealed: str
|
||||
```
|
|
@ -734,12 +734,20 @@ where
|
|||
self.flow_merge(break_state);
|
||||
}
|
||||
}
|
||||
ast::Stmt::With(ast::StmtWith { items, body, .. }) => {
|
||||
ast::Stmt::With(ast::StmtWith {
|
||||
items,
|
||||
body,
|
||||
is_async,
|
||||
..
|
||||
}) => {
|
||||
for item in items {
|
||||
self.visit_expr(&item.context_expr);
|
||||
if let Some(optional_vars) = item.optional_vars.as_deref() {
|
||||
self.add_standalone_expression(&item.context_expr);
|
||||
self.push_assignment(item.into());
|
||||
self.push_assignment(CurrentAssignment::WithItem {
|
||||
item,
|
||||
is_async: *is_async,
|
||||
});
|
||||
self.visit_expr(optional_vars);
|
||||
self.pop_assignment();
|
||||
}
|
||||
|
@ -1011,12 +1019,13 @@ where
|
|||
},
|
||||
);
|
||||
}
|
||||
Some(CurrentAssignment::WithItem(with_item)) => {
|
||||
Some(CurrentAssignment::WithItem { item, is_async }) => {
|
||||
self.add_definition(
|
||||
symbol,
|
||||
WithItemDefinitionNodeRef {
|
||||
node: with_item,
|
||||
node: item,
|
||||
target: name_node,
|
||||
is_async,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -1232,7 +1241,10 @@ enum CurrentAssignment<'a> {
|
|||
node: &'a ast::Comprehension,
|
||||
first: bool,
|
||||
},
|
||||
WithItem(&'a ast::WithItem),
|
||||
WithItem {
|
||||
item: &'a ast::WithItem,
|
||||
is_async: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> {
|
||||
|
@ -1259,12 +1271,6 @@ impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a ast::WithItem> for CurrentAssignment<'a> {
|
||||
fn from(value: &'a ast::WithItem) -> Self {
|
||||
Self::WithItem(value)
|
||||
}
|
||||
}
|
||||
|
||||
struct CurrentMatchCase<'a> {
|
||||
/// The pattern that's part of the current match case.
|
||||
pattern: &'a ast::Pattern,
|
||||
|
|
|
@ -176,6 +176,7 @@ pub(crate) struct AssignmentDefinitionNodeRef<'a> {
|
|||
pub(crate) struct WithItemDefinitionNodeRef<'a> {
|
||||
pub(crate) node: &'a ast::WithItem,
|
||||
pub(crate) target: &'a ast::ExprName,
|
||||
pub(crate) is_async: bool,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
|
@ -277,12 +278,15 @@ impl DefinitionNodeRef<'_> {
|
|||
DefinitionKind::ParameterWithDefault(AstNodeRef::new(parsed, parameter))
|
||||
}
|
||||
},
|
||||
DefinitionNodeRef::WithItem(WithItemDefinitionNodeRef { node, target }) => {
|
||||
DefinitionKind::WithItem(WithItemDefinitionKind {
|
||||
node: AstNodeRef::new(parsed.clone(), node),
|
||||
target: AstNodeRef::new(parsed, target),
|
||||
})
|
||||
}
|
||||
DefinitionNodeRef::WithItem(WithItemDefinitionNodeRef {
|
||||
node,
|
||||
target,
|
||||
is_async,
|
||||
}) => DefinitionKind::WithItem(WithItemDefinitionKind {
|
||||
node: AstNodeRef::new(parsed.clone(), node),
|
||||
target: AstNodeRef::new(parsed, target),
|
||||
is_async,
|
||||
}),
|
||||
DefinitionNodeRef::MatchPattern(MatchPatternDefinitionNodeRef {
|
||||
pattern,
|
||||
identifier,
|
||||
|
@ -329,7 +333,11 @@ impl DefinitionNodeRef<'_> {
|
|||
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
|
||||
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
|
||||
},
|
||||
Self::WithItem(WithItemDefinitionNodeRef { node: _, target }) => target.into(),
|
||||
Self::WithItem(WithItemDefinitionNodeRef {
|
||||
node: _,
|
||||
target,
|
||||
is_async: _,
|
||||
}) => target.into(),
|
||||
Self::MatchPattern(MatchPatternDefinitionNodeRef { identifier, .. }) => {
|
||||
identifier.into()
|
||||
}
|
||||
|
@ -534,6 +542,7 @@ pub enum AssignmentKind {
|
|||
pub struct WithItemDefinitionKind {
|
||||
node: AstNodeRef<ast::WithItem>,
|
||||
target: AstNodeRef<ast::ExprName>,
|
||||
is_async: bool,
|
||||
}
|
||||
|
||||
impl WithItemDefinitionKind {
|
||||
|
@ -544,6 +553,10 @@ impl WithItemDefinitionKind {
|
|||
pub(crate) fn target(&self) -> &ast::ExprName {
|
||||
self.target.node()
|
||||
}
|
||||
|
||||
pub(crate) const fn is_async(&self) -> bool {
|
||||
self.is_async
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
|
|
@ -1009,9 +1009,7 @@ impl<'db> Type<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Return the type resulting from calling an object of this type.
|
||||
///
|
||||
/// Returns `None` if `self` is not a callable type.
|
||||
/// Return the outcome of calling an object of this type.
|
||||
#[must_use]
|
||||
fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> {
|
||||
match self {
|
||||
|
|
|
@ -359,6 +359,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
/// Get the already-inferred type of an expression node.
|
||||
///
|
||||
/// PANIC if no type has been inferred for this node.
|
||||
#[track_caller]
|
||||
fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> {
|
||||
self.types
|
||||
.expression_ty(expr.scoped_ast_id(self.db, self.scope()))
|
||||
|
@ -479,7 +480,12 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
self.infer_parameter_with_default_definition(parameter_with_default, definition);
|
||||
}
|
||||
DefinitionKind::WithItem(with_item) => {
|
||||
self.infer_with_item_definition(with_item.target(), with_item.node(), definition);
|
||||
self.infer_with_item_definition(
|
||||
with_item.target(),
|
||||
with_item.node(),
|
||||
with_item.is_async(),
|
||||
definition,
|
||||
);
|
||||
}
|
||||
DefinitionKind::MatchPattern(match_pattern) => {
|
||||
self.infer_match_pattern_definition(
|
||||
|
@ -973,18 +979,21 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
fn infer_with_statement(&mut self, with_statement: &ast::StmtWith) {
|
||||
let ast::StmtWith {
|
||||
range: _,
|
||||
is_async: _,
|
||||
is_async,
|
||||
items,
|
||||
body,
|
||||
} = with_statement;
|
||||
|
||||
for item in items {
|
||||
let target = item.optional_vars.as_deref();
|
||||
if let Some(ast::Expr::Name(name)) = target {
|
||||
self.infer_definition(name);
|
||||
} else {
|
||||
// TODO infer definitions in unpacking assignment
|
||||
self.infer_expression(&item.context_expr);
|
||||
|
||||
// Call into the context expression inference to validate that it evaluates
|
||||
// to a valid context manager.
|
||||
let context_expression_ty = self.infer_expression(&item.context_expr);
|
||||
self.infer_context_expression(&item.context_expr, context_expression_ty, *is_async);
|
||||
self.infer_optional_expression(target);
|
||||
}
|
||||
}
|
||||
|
@ -996,20 +1005,116 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
&mut self,
|
||||
target: &ast::ExprName,
|
||||
with_item: &ast::WithItem,
|
||||
is_async: bool,
|
||||
definition: Definition<'db>,
|
||||
) {
|
||||
let expression = self.index.expression(&with_item.context_expr);
|
||||
let result = infer_expression_types(self.db, expression);
|
||||
self.extend(result);
|
||||
let context_expr = self.index.expression(&with_item.context_expr);
|
||||
self.extend(infer_expression_types(self.db, context_expr));
|
||||
|
||||
// TODO(dhruvmanila): The correct type inference here is the return type of the __enter__
|
||||
// method of the context manager.
|
||||
let context_expr_ty = self.expression_ty(&with_item.context_expr);
|
||||
let target_ty = self.infer_context_expression(
|
||||
&with_item.context_expr,
|
||||
self.expression_ty(&with_item.context_expr),
|
||||
is_async,
|
||||
);
|
||||
|
||||
self.types
|
||||
.expressions
|
||||
.insert(target.scoped_ast_id(self.db, self.scope()), context_expr_ty);
|
||||
self.add_binding(target.into(), definition, context_expr_ty);
|
||||
.insert(target.scoped_ast_id(self.db, self.scope()), target_ty);
|
||||
self.add_binding(target.into(), definition, target_ty);
|
||||
}
|
||||
|
||||
/// Infers the type of a context expression (`with expr`) and returns the target's type
|
||||
///
|
||||
/// Returns [`Type::Unknown`] if the context expression doesn't implement the context manager protocol.
|
||||
///
|
||||
/// ## Terminology
|
||||
/// See [PEP343](https://peps.python.org/pep-0343/#standard-terminology).
|
||||
fn infer_context_expression(
|
||||
&mut self,
|
||||
context_expression: &ast::Expr,
|
||||
context_expression_ty: Type<'db>,
|
||||
is_async: bool,
|
||||
) -> Type<'db> {
|
||||
// TODO: Handle async with statements (they use `aenter` and `aexit`)
|
||||
if is_async {
|
||||
return Type::Todo;
|
||||
}
|
||||
|
||||
let context_manager_ty = context_expression_ty.to_meta_type(self.db);
|
||||
|
||||
let enter_ty = context_manager_ty.member(self.db, "__enter__");
|
||||
let exit_ty = context_manager_ty.member(self.db, "__exit__");
|
||||
|
||||
// TODO: Make use of Protocols when we support it (the manager be assignable to `contextlib.AbstractContextManager`).
|
||||
if enter_ty.is_unbound() && exit_ty.is_unbound() {
|
||||
self.diagnostics.add(
|
||||
context_expression.into(),
|
||||
"invalid-context-manager",
|
||||
format_args!(
|
||||
"Object of type {} cannot be used with `with` because it doesn't implement `__enter__` and `__exit__`",
|
||||
context_expression_ty.display(self.db)
|
||||
),
|
||||
);
|
||||
Type::Unknown
|
||||
} else if enter_ty.is_unbound() {
|
||||
self.diagnostics.add(
|
||||
context_expression.into(),
|
||||
"invalid-context-manager",
|
||||
format_args!(
|
||||
"Object of type {} cannot be used with `with` because it doesn't implement `__enter__`",
|
||||
context_expression_ty.display(self.db)
|
||||
),
|
||||
);
|
||||
Type::Unknown
|
||||
} else {
|
||||
let target_ty = enter_ty
|
||||
.call(self.db, &[context_expression_ty])
|
||||
.return_ty_result(self.db, context_expression.into(), &mut self.diagnostics)
|
||||
.unwrap_or_else(|err| {
|
||||
self.diagnostics.add(
|
||||
context_expression.into(),
|
||||
"invalid-context-manager",
|
||||
format_args!("
|
||||
Object of type {context_expression} cannot be used with `with` because the method `__enter__` of type {enter_ty} is not callable",
|
||||
context_expression = context_expression_ty.display(self.db),
|
||||
enter_ty = enter_ty.display(self.db)
|
||||
),
|
||||
);
|
||||
err.return_ty()
|
||||
});
|
||||
|
||||
if exit_ty.is_unbound() {
|
||||
self.diagnostics.add(
|
||||
context_expression.into(),
|
||||
"invalid-context-manager",
|
||||
format_args!(
|
||||
"Object of type {} cannot be used with `with` because it doesn't implement `__exit__`",
|
||||
context_expression_ty.display(self.db)
|
||||
),
|
||||
);
|
||||
}
|
||||
// TODO: Use the `exit_ty` to determine if any raised exception is suppressed.
|
||||
else if exit_ty
|
||||
.call(
|
||||
self.db,
|
||||
&[context_manager_ty, Type::None, Type::None, Type::None],
|
||||
)
|
||||
.return_ty_result(self.db, context_expression.into(), &mut self.diagnostics)
|
||||
.is_err()
|
||||
{
|
||||
self.diagnostics.add(
|
||||
context_expression.into(),
|
||||
"invalid-context-manager",
|
||||
format_args!(
|
||||
"Object of type {context_expression} cannot be used with `with` because the method `__exit__` of type {exit_ty} is not callable",
|
||||
context_expression = context_expression_ty.display(self.db),
|
||||
exit_ty = exit_ty.display(self.db),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
target_ty
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_except_handler_definition(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue