[red-knot] Support unpacking with target (#16469)

## Summary

Resolves #16365

Add support for unpacking `with` statement targets.

## Test Plan

Added some test cases, alike the ones added by #15058.

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Eric Mark Martin 2025-03-07 21:36:35 -05:00 committed by GitHub
parent 820a31af5d
commit 24c8b1242e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 505 additions and 219 deletions

View file

@ -216,6 +216,17 @@ impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> {
self.visit_body(&for_stmt.orelse);
return;
}
Stmt::With(with_stmt) => {
for item in &with_stmt.items {
if let Some(target) = &item.optional_vars {
self.visit_target(target);
}
self.visit_expr(&item.context_expr);
}
self.visit_body(&with_stmt.body);
return;
}
Stmt::AnnAssign(_)
| Stmt::Return(_)
| Stmt::Delete(_)
@ -223,7 +234,6 @@ impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> {
| Stmt::TypeAlias(_)
| Stmt::While(_)
| Stmt::If(_)
| Stmt::With(_)
| Stmt::Match(_)
| Stmt::Raise(_)
| Stmt::Try(_)

View file

@ -358,9 +358,25 @@ class C:
c_instance = C()
# TODO: Should be `Unknown | int | None`
# error: [unresolved-attribute]
reveal_type(c_instance.x) # revealed: Unknown
reveal_type(c_instance.x) # revealed: Unknown | int | None
```
#### Attributes defined in `with` statements, but with unpacking
```py
class ContextManager:
def __enter__(self) -> tuple[int | None, int]: ...
def __exit__(self, exc_type, exc_value, traceback) -> None: ...
class C:
def __init__(self) -> None:
with ContextManager() as (self.x, self.y):
pass
c_instance = C()
reveal_type(c_instance.x) # revealed: Unknown | int | None
reveal_type(c_instance.y) # revealed: Unknown | int
```
#### Attributes defined in comprehensions

View file

@ -613,3 +613,98 @@ def _(arg: tuple[tuple[int, str], Iterable]):
reveal_type(a) # revealed: int | bytes
reveal_type(b) # revealed: str | bytes
```
## With statement
Unpacking in a `with` statement.
### Same types
```py
class ContextManager:
def __enter__(self) -> tuple[int, int]:
return (1, 2)
def __exit__(self, exc_type, exc_value, traceback) -> None:
pass
with ContextManager() as (a, b):
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int
```
### Mixed types
```py
class ContextManager:
def __enter__(self) -> tuple[int, str]:
return (1, "a")
def __exit__(self, exc_type, exc_value, traceback) -> None:
pass
with ContextManager() as (a, b):
reveal_type(a) # revealed: int
reveal_type(b) # revealed: str
```
### Nested
```py
class ContextManager:
def __enter__(self) -> tuple[int, tuple[str, bytes]]:
return (1, ("a", b"bytes"))
def __exit__(self, exc_type, exc_value, traceback) -> None:
pass
with ContextManager() as (a, (b, c)):
reveal_type(a) # revealed: int
reveal_type(b) # revealed: str
reveal_type(c) # revealed: bytes
```
### Starred expression
```py
class ContextManager:
def __enter__(self) -> tuple[int, int, int]:
return (1, 2, 3)
def __exit__(self, exc_type, exc_value, traceback) -> None:
pass
with ContextManager() as (a, *b):
reveal_type(a) # revealed: int
# TODO: Should be list[int] once support for assigning to starred expression is added
reveal_type(b) # revealed: @Todo(starred unpacking)
```
### Unbound context manager expression
```py
# TODO: should only be one diagnostic
# error: [unresolved-reference] "Name `nonexistant` used when not defined"
# error: [unresolved-reference] "Name `nonexistant` used when not defined"
# error: [unresolved-reference] "Name `nonexistant` used when not defined"
with nonexistant as (x, y):
reveal_type(x) # revealed: Unknown
reveal_type(y) # revealed: Unknown
```
### Invalid unpacking
```py
class ContextManager:
def __enter__(self) -> tuple[int, str]:
return (1, "a")
def __exit__(self, *args) -> None:
pass
# error: [invalid-assignment] "Not enough values to unpack (expected 3, got 2)"
with ContextManager() as (a, b, c):
reveal_type(a) # revealed: Unknown
reveal_type(b) # revealed: Unknown
reveal_type(c) # revealed: Unknown
```

View file

@ -45,7 +45,7 @@ def _(flag: bool):
```py
class Manager: ...
# error: [invalid-context-manager] "Object of type `Manager` cannot be used with `with` because it doesn't implement `__enter__` and `__exit__`"
# error: [invalid-context-manager] "Object of type `Manager` cannot be used with `with` because it does not implement `__enter__` and `__exit__`"
with Manager():
...
```
@ -56,7 +56,7 @@ with Manager():
class Manager:
def __exit__(self, exc_tpe, exc_value, traceback): ...
# error: [invalid-context-manager] "Object of type `Manager` cannot be used with `with` because it doesn't implement `__enter__`"
# error: [invalid-context-manager] "Object of type `Manager` cannot be used with `with` because it does not implement `__enter__`"
with Manager():
...
```
@ -67,7 +67,7 @@ with Manager():
class Manager:
def __enter__(self): ...
# error: [invalid-context-manager] "Object of type `Manager` cannot be used with `with` because it doesn't implement `__exit__`"
# error: [invalid-context-manager] "Object of type `Manager` cannot be used with `with` because it does not implement `__exit__`"
with Manager():
...
```
@ -113,8 +113,7 @@ def _(flag: bool):
class NotAContextManager: ...
context_expr = Manager1() if flag else NotAContextManager()
# error: [invalid-context-manager] "Object of type `Manager1 | NotAContextManager` cannot be used with `with` because the method `__enter__` is possibly unbound"
# error: [invalid-context-manager] "Object of type `Manager1 | NotAContextManager` cannot be used with `with` because the method `__exit__` is possibly unbound"
# error: [invalid-context-manager] "Object of type `Manager1 | NotAContextManager` cannot be used with `with` because the methods `__enter__` and `__exit__` are possibly unbound"
with context_expr as f:
reveal_type(f) # revealed: str
```

View file

@ -45,7 +45,7 @@ pub struct AstNodeRef<T> {
#[allow(unsafe_code)]
impl<T> AstNodeRef<T> {
/// Creates a new `AstNodeRef` that reference `node`. The `parsed` is the [`ParsedModule`] to
/// Creates a new `AstNodeRef` that references `node`. The `parsed` is the [`ParsedModule`] to
/// which the `AstNodeRef` belongs.
///
/// ## Safety

View file

@ -22,6 +22,10 @@ pub(crate) enum AttributeAssignment<'db> {
/// `for self.x in <iterable>`.
Iterable { iterable: Expression<'db> },
/// An attribute assignment where the expression to be assigned is a context manager, for example
/// `with <context_manager> as self.x`.
ContextManager { context_manager: Expression<'db> },
/// An attribute assignment where the left-hand side is an unpacking expression,
/// e.g. `self.x, self.y = <value>`.
Unpack {

View file

@ -1032,6 +1032,7 @@ where
self.db,
self.file,
self.current_scope(),
// SAFETY: `target` belongs to the `self.module` tree
#[allow(unsafe_code)]
unsafe {
AstNodeRef::new(self.module.clone(), target)
@ -1262,16 +1263,64 @@ where
is_async,
..
}) => {
for item in items {
self.visit_expr(&item.context_expr);
if let Some(optional_vars) = item.optional_vars.as_deref() {
self.add_standalone_expression(&item.context_expr);
self.push_assignment(CurrentAssignment::WithItem {
item,
is_async: *is_async,
});
for item @ ruff_python_ast::WithItem {
range: _,
context_expr,
optional_vars,
} in items
{
self.visit_expr(context_expr);
if let Some(optional_vars) = optional_vars.as_deref() {
let context_manager = self.add_standalone_expression(context_expr);
let current_assignment = match optional_vars {
ast::Expr::Tuple(_) | ast::Expr::List(_) => {
Some(CurrentAssignment::WithItem {
item,
first: true,
is_async: *is_async,
unpack: Some(Unpack::new(
self.db,
self.file,
self.current_scope(),
// SAFETY: the node `optional_vars` belongs to the `self.module` tree
#[allow(unsafe_code)]
unsafe {
AstNodeRef::new(self.module.clone(), optional_vars)
},
UnpackValue::ContextManager(context_manager),
countme::Count::default(),
)),
})
}
ast::Expr::Name(_) => Some(CurrentAssignment::WithItem {
item,
is_async: *is_async,
unpack: None,
// `false` is arbitrary here---we don't actually use it other than in the actual unpacks
first: false,
}),
ast::Expr::Attribute(ast::ExprAttribute {
value: object,
attr,
..
}) => {
self.register_attribute_assignment(
object,
attr,
AttributeAssignment::ContextManager { context_manager },
);
None
}
_ => None,
};
if let Some(current_assignment) = current_assignment {
self.push_assignment(current_assignment);
}
self.visit_expr(optional_vars);
self.pop_assignment();
if current_assignment.is_some() {
self.pop_assignment();
}
}
}
self.visit_body(body);
@ -1304,6 +1353,7 @@ where
self.db,
self.file,
self.current_scope(),
// SAFETY: the node `target` belongs to the `self.module` tree
#[allow(unsafe_code)]
unsafe {
AstNodeRef::new(self.module.clone(), target)
@ -1631,12 +1681,19 @@ where
},
);
}
Some(CurrentAssignment::WithItem { item, is_async }) => {
Some(CurrentAssignment::WithItem {
item,
first,
is_async,
unpack,
}) => {
self.add_definition(
symbol,
WithItemDefinitionNodeRef {
node: item,
target: name_node,
unpack,
context_expr: &item.context_expr,
name: name_node,
first,
is_async,
},
);
@ -1646,7 +1703,9 @@ where
}
if let Some(
CurrentAssignment::Assign { first, .. } | CurrentAssignment::For { first, .. },
CurrentAssignment::Assign { first, .. }
| CurrentAssignment::For { first, .. }
| CurrentAssignment::WithItem { first, .. },
) = self.current_assignment_mut()
{
*first = false;
@ -1826,6 +1885,10 @@ where
| CurrentAssignment::For {
unpack: Some(unpack),
..
}
| CurrentAssignment::WithItem {
unpack: Some(unpack),
..
},
) = self.current_assignment()
{
@ -1919,7 +1982,9 @@ enum CurrentAssignment<'a> {
},
WithItem {
item: &'a ast::WithItem,
first: bool,
is_async: bool,
unpack: Option<Unpack<'a>>,
},
}

View file

@ -201,8 +201,10 @@ pub(crate) struct AssignmentDefinitionNodeRef<'a> {
#[derive(Copy, Clone, Debug)]
pub(crate) struct WithItemDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::WithItem,
pub(crate) target: &'a ast::ExprName,
pub(crate) unpack: Option<Unpack<'a>>,
pub(crate) context_expr: &'a ast::Expr,
pub(crate) name: &'a ast::ExprName,
pub(crate) first: bool,
pub(crate) is_async: bool,
}
@ -323,12 +325,16 @@ impl<'db> DefinitionNodeRef<'db> {
DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter))
}
DefinitionNodeRef::WithItem(WithItemDefinitionNodeRef {
node,
target,
unpack,
context_expr,
name,
first,
is_async,
}) => DefinitionKind::WithItem(WithItemDefinitionKind {
node: AstNodeRef::new(parsed.clone(), node),
target: AstNodeRef::new(parsed, target),
target: TargetKind::from(unpack),
context_expr: AstNodeRef::new(parsed.clone(), context_expr),
name: AstNodeRef::new(parsed, name),
first,
is_async,
}),
DefinitionNodeRef::MatchPattern(MatchPatternDefinitionNodeRef {
@ -394,10 +400,12 @@ impl<'db> DefinitionNodeRef<'db> {
Self::VariadicKeywordParameter(node) => node.into(),
Self::Parameter(node) => node.into(),
Self::WithItem(WithItemDefinitionNodeRef {
node: _,
target,
unpack: _,
context_expr: _,
first: _,
is_async: _,
}) => target.into(),
name,
}) => name.into(),
Self::MatchPattern(MatchPatternDefinitionNodeRef { identifier, .. }) => {
identifier.into()
}
@ -467,7 +475,7 @@ pub enum DefinitionKind<'db> {
VariadicPositionalParameter(AstNodeRef<ast::Parameter>),
VariadicKeywordParameter(AstNodeRef<ast::Parameter>),
Parameter(AstNodeRef<ast::ParameterWithDefault>),
WithItem(WithItemDefinitionKind),
WithItem(WithItemDefinitionKind<'db>),
MatchPattern(MatchPatternDefinitionKind),
ExceptHandler(ExceptHandlerDefinitionKind),
TypeVar(AstNodeRef<ast::TypeParamTypeVar>),
@ -506,7 +514,7 @@ impl DefinitionKind<'_> {
DefinitionKind::VariadicPositionalParameter(parameter) => parameter.name.range(),
DefinitionKind::VariadicKeywordParameter(parameter) => parameter.name.range(),
DefinitionKind::Parameter(parameter) => parameter.parameter.name.range(),
DefinitionKind::WithItem(with_item) => with_item.target().range(),
DefinitionKind::WithItem(with_item) => with_item.name().range(),
DefinitionKind::MatchPattern(match_pattern) => match_pattern.identifier.range(),
DefinitionKind::ExceptHandler(handler) => handler.node().range(),
DefinitionKind::TypeVar(type_var) => type_var.name.range(),
@ -688,19 +696,29 @@ impl<'db> AssignmentDefinitionKind<'db> {
}
#[derive(Clone, Debug)]
pub struct WithItemDefinitionKind {
node: AstNodeRef<ast::WithItem>,
target: AstNodeRef<ast::ExprName>,
pub struct WithItemDefinitionKind<'db> {
target: TargetKind<'db>,
context_expr: AstNodeRef<ast::Expr>,
name: AstNodeRef<ast::ExprName>,
first: bool,
is_async: bool,
}
impl WithItemDefinitionKind {
pub(crate) fn node(&self) -> &ast::WithItem {
self.node.node()
impl<'db> WithItemDefinitionKind<'db> {
pub(crate) fn context_expr(&self) -> &ast::Expr {
self.context_expr.node()
}
pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
pub(crate) fn target(&self) -> TargetKind<'db> {
self.target
}
pub(crate) fn name(&self) -> &ast::ExprName {
self.name.node()
}
pub(crate) const fn is_first(&self) -> bool {
self.first
}
pub(crate) const fn is_async(&self) -> bool {

View file

@ -4,7 +4,7 @@ use std::str::FromStr;
use bitflags::bitflags;
use call::{CallDunderError, CallError};
use context::InferContext;
use diagnostic::NOT_ITERABLE;
use diagnostic::{INVALID_CONTEXT_MANAGER, NOT_ITERABLE};
use ruff_db::files::File;
use ruff_python_ast as ast;
use ruff_python_ast::name::Name;
@ -2841,6 +2841,52 @@ impl<'db> Type<'db> {
})
}
/// Returns the type bound from a context manager with type `self`.
///
/// This method should only be used outside of type checking because it omits any errors.
/// For type checking, use [`try_enter`](Self::try_enter) instead.
fn enter(self, db: &'db dyn Db) -> Type<'db> {
self.try_enter(db)
.unwrap_or_else(|err| err.fallback_enter_type(db))
}
/// Given the type of an object that is used as a context manager (i.e. in a `with` statement),
/// return the return type of its `__enter__` method, which is bound to any potential targets.
///
/// E.g., for the following `with` statement, given the type of `x`, infer the type of `y`:
/// ```python
/// with x as y:
/// pass
/// ```
fn try_enter(self, db: &'db dyn Db) -> Result<Type<'db>, ContextManagerError<'db>> {
let enter = self.try_call_dunder(db, "__enter__", &CallArguments::none());
let exit = self.try_call_dunder(
db,
"__exit__",
&CallArguments::positional([Type::none(db), Type::none(db), Type::none(db)]),
);
// TODO: Make use of Protocols when we support it (the manager be assignable to `contextlib.AbstractContextManager`).
let result = match (enter, exit) {
(Ok(enter), Ok(_)) => Ok(enter.return_type(db)),
(Ok(enter), Err(exit_error)) => Err(ContextManagerErrorKind::Exit {
enter_return_type: enter.return_type(db),
exit_error,
}),
// TODO: Use the `exit_ty` to determine if any raised exception is suppressed.
(Err(enter_error), Ok(_)) => Err(ContextManagerErrorKind::Enter(enter_error)),
(Err(enter_error), Err(exit_error)) => Err(ContextManagerErrorKind::EnterAndExit {
enter_error,
exit_error,
}),
};
result.map_err(|error_kind| ContextManagerError {
context_manager_type: self,
error_kind,
})
}
#[must_use]
pub fn to_instance(&self, db: &'db dyn Db) -> Type<'db> {
match self {
@ -3374,6 +3420,135 @@ pub enum TypeVarBoundOrConstraints<'db> {
Constraints(TupleType<'db>),
}
/// Error returned if a type is not (or may not be) a context manager.
#[derive(Debug)]
struct ContextManagerError<'db> {
/// The type of the object that the analysed code attempted to use as a context manager.
context_manager_type: Type<'db>,
/// The precise kind of error encountered when trying to use the type as a context manager.
error_kind: ContextManagerErrorKind<'db>,
}
impl<'db> ContextManagerError<'db> {
fn enter_type(&self, db: &'db dyn Db) -> Option<Type<'db>> {
self.error_kind.enter_type(db)
}
fn fallback_enter_type(&self, db: &'db dyn Db) -> Type<'db> {
self.enter_type(db).unwrap_or(Type::unknown())
}
/// Reports the diagnostic for this error
fn report_diagnostic(
&self,
context: &InferContext<'db>,
context_manager_node: ast::AnyNodeRef,
) {
self.error_kind
.report_diagnostic(context, self.context_manager_type, context_manager_node);
}
}
#[derive(Debug)]
enum ContextManagerErrorKind<'db> {
Exit {
enter_return_type: Type<'db>,
exit_error: CallDunderError<'db>,
},
Enter(CallDunderError<'db>),
EnterAndExit {
enter_error: CallDunderError<'db>,
exit_error: CallDunderError<'db>,
},
}
impl<'db> ContextManagerErrorKind<'db> {
fn enter_type(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
ContextManagerErrorKind::Exit {
enter_return_type,
exit_error: _,
} => Some(*enter_return_type),
ContextManagerErrorKind::Enter(enter_error)
| ContextManagerErrorKind::EnterAndExit {
enter_error,
exit_error: _,
} => match enter_error {
CallDunderError::PossiblyUnbound(call_outcome) => {
Some(call_outcome.return_type(db))
}
CallDunderError::Call(call_error) => call_error.return_type(db),
CallDunderError::MethodNotAvailable => None,
},
}
}
fn report_diagnostic(
&self,
context: &InferContext<'db>,
context_expression_ty: Type<'db>,
context_expression_node: ast::AnyNodeRef,
) {
let format_call_dunder_error = |call_dunder_error: &CallDunderError<'db>, name: &str| {
match call_dunder_error {
CallDunderError::MethodNotAvailable => format!("it does not implement `{name}`"),
CallDunderError::PossiblyUnbound(_) => {
format!("the method `{name}` is possibly unbound")
}
// TODO: Use more specific error messages for the different error cases.
// E.g. hint toward the union variant that doesn't correctly implement enter,
// distinguish between a not callable `__enter__` attribute and a wrong signature.
CallDunderError::Call(_) => format!("it does not correctly implement `{name}`"),
}
};
let format_call_dunder_errors = |error_a: &CallDunderError<'db>,
name_a: &str,
error_b: &CallDunderError<'db>,
name_b: &str| {
match (error_a, error_b) {
(CallDunderError::PossiblyUnbound(_), CallDunderError::PossiblyUnbound(_)) => {
format!("the methods `{name_a}` and `{name_b}` are possibly unbound")
}
(CallDunderError::MethodNotAvailable, CallDunderError::MethodNotAvailable) => {
format!("it does not implement `{name_a}` and `{name_b}`")
}
(CallDunderError::Call(_), CallDunderError::Call(_)) => {
format!("it does not correctly implement `{name_a}` or `{name_b}`")
}
(_, _) => format!(
"{format_a}, and {format_b}",
format_a = format_call_dunder_error(error_a, name_a),
format_b = format_call_dunder_error(error_b, name_b)
),
}
};
let db = context.db();
let formatted_errors = match self {
ContextManagerErrorKind::Exit {
enter_return_type: _,
exit_error,
} => format_call_dunder_error(exit_error, "__exit__"),
ContextManagerErrorKind::Enter(enter_error) => {
format_call_dunder_error(enter_error, "__enter__")
}
ContextManagerErrorKind::EnterAndExit {
enter_error,
exit_error,
} => format_call_dunder_errors(enter_error, "__enter__", exit_error, "__exit__"),
};
context.report_lint(&INVALID_CONTEXT_MANAGER, context_expression_node,
format_args!("Object of type `{context_expression}` cannot be used with `with` because {formatted_errors}",
context_expression = context_expression_ty.display(db)
),
);
}
}
/// Error returned if a type is not (or may not be) iterable.
#[derive(Debug)]
struct IterationError<'db> {

View file

@ -513,6 +513,16 @@ impl<'db> Class<'db> {
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
}
AttributeAssignment::ContextManager { context_manager } => {
// We found an attribute assignment like:
//
// with <context_manager> as self.name:
let context_ty = infer_expression_type(db, *context_manager);
let inferred_ty = context_ty.enter(db);
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
}
AttributeAssignment::Unpack {
attribute_expression_id,
unpack,

View file

@ -43,7 +43,7 @@ use crate::module_resolver::{file_to_module, resolve_module};
use crate::semantic_index::ast_ids::{HasScopedExpressionId, HasScopedUseId, ScopedExpressionId};
use crate::semantic_index::definition::{
AssignmentDefinitionKind, Definition, DefinitionKind, DefinitionNodeKey,
ExceptHandlerDefinitionKind, ForStmtDefinitionKind, TargetKind,
ExceptHandlerDefinitionKind, ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind,
};
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::semantic_index;
@ -60,10 +60,10 @@ use crate::types::diagnostic::{
report_invalid_attribute_assignment, report_unresolved_module, TypeCheckDiagnostics,
CALL_NON_CALLABLE, CALL_POSSIBLY_UNBOUND_METHOD, CONFLICTING_DECLARATIONS,
CONFLICTING_METACLASS, CYCLIC_CLASS_DEFINITION, DIVISION_BY_ZERO, DUPLICATE_BASE,
INCONSISTENT_MRO, INVALID_ATTRIBUTE_ACCESS, INVALID_BASE, INVALID_CONTEXT_MANAGER,
INVALID_DECLARATION, INVALID_PARAMETER_DEFAULT, INVALID_TYPE_FORM,
INVALID_TYPE_VARIABLE_CONSTRAINTS, POSSIBLY_UNBOUND_ATTRIBUTE, POSSIBLY_UNBOUND_IMPORT,
UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT, UNSUPPORTED_OPERATOR,
INCONSISTENT_MRO, INVALID_ATTRIBUTE_ACCESS, INVALID_BASE, INVALID_DECLARATION,
INVALID_PARAMETER_DEFAULT, INVALID_TYPE_FORM, INVALID_TYPE_VARIABLE_CONSTRAINTS,
POSSIBLY_UNBOUND_ATTRIBUTE, POSSIBLY_UNBOUND_IMPORT, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE,
UNRESOLVED_IMPORT, UNSUPPORTED_OPERATOR,
};
use crate::types::mro::MroErrorKind;
use crate::types::unpacker::{UnpackResult, Unpacker};
@ -839,13 +839,8 @@ impl<'db> TypeInferenceBuilder<'db> {
DefinitionKind::Parameter(parameter_with_default) => {
self.infer_parameter_definition(parameter_with_default, definition);
}
DefinitionKind::WithItem(with_item) => {
self.infer_with_item_definition(
with_item.target(),
with_item.node(),
with_item.is_async(),
definition,
);
DefinitionKind::WithItem(with_item_definition) => {
self.infer_with_item_definition(with_item_definition, definition);
}
DefinitionKind::MatchPattern(match_pattern) => {
self.infer_match_pattern_definition(
@ -1597,18 +1592,17 @@ impl<'db> TypeInferenceBuilder<'db> {
} = with_statement;
for item in items {
let target = item.optional_vars.as_deref();
if let Some(ast::Expr::Name(name)) = target {
self.infer_definition(name);
if let Some(target) = target {
self.infer_target(target, &item.context_expr, |db, ctx_manager_ty| {
// TODO: `infer_with_statement_definition` reports a diagnostic if `ctx_manager_ty` isn't a context manager
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `with not_context_manager as a.x: ...
ctx_manager_ty.enter(db)
});
} else {
// TODO infer definitions in unpacking assignment
// Call into the context expression inference to validate that it evaluates
// to a valid context manager.
let context_expression_ty = if target.is_some() {
self.infer_standalone_expression(&item.context_expr)
} else {
self.infer_expression(&item.context_expr)
};
let context_expression_ty = self.infer_expression(&item.context_expr);
self.infer_context_expression(&item.context_expr, context_expression_ty, *is_async);
self.infer_optional_expression(target);
}
@ -1619,24 +1613,36 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_with_item_definition(
&mut self,
target: &ast::ExprName,
with_item: &ast::WithItem,
is_async: bool,
with_item: &WithItemDefinitionKind<'db>,
definition: Definition<'db>,
) {
self.infer_standalone_expression(&with_item.context_expr);
let context_expr = with_item.context_expr();
let name = with_item.name();
let target_ty = self.infer_context_expression(
&with_item.context_expr,
self.expression_type(&with_item.context_expr),
is_async,
);
let context_expr_ty = self.infer_standalone_expression(context_expr);
self.types.expressions.insert(
target.scoped_expression_id(self.db(), self.scope()),
target_ty,
);
self.add_binding(target.into(), definition, target_ty);
let target_ty = if with_item.is_async() {
todo_type!("async `with` statement")
} else {
match with_item.target() {
TargetKind::Sequence(unpack) => {
let unpacked = infer_unpack_types(self.db(), unpack);
let name_ast_id = name.scoped_expression_id(self.db(), self.scope());
if with_item.is_first() {
self.context.extend(unpacked);
}
unpacked.expression_type(name_ast_id)
}
TargetKind::Name => self.infer_context_expression(
context_expr,
context_expr_ty,
with_item.is_async(),
),
}
};
self.store_expression_type(name, target_ty);
self.add_binding(name.into(), definition, target_ty);
}
/// Infers the type of a context expression (`with expr`) and returns the target's type
@ -1656,120 +1662,12 @@ impl<'db> TypeInferenceBuilder<'db> {
return todo_type!("async `with` statement");
}
let context_manager_ty = context_expression_ty.to_meta_type(self.db());
let enter = context_manager_ty.member(self.db(), "__enter__").symbol;
let exit = context_manager_ty.member(self.db(), "__exit__").symbol;
// TODO: Make use of Protocols when we support it (the manager be assignable to `contextlib.AbstractContextManager`).
match (enter, exit) {
(Symbol::Unbound, Symbol::Unbound) => {
self.context.report_lint(
&INVALID_CONTEXT_MANAGER,
context_expression,
format_args!(
"Object of type `{}` cannot be used with `with` because it doesn't implement `__enter__` and `__exit__`",
context_expression_ty.display(self.db())
),
);
Type::unknown()
}
(Symbol::Unbound, _) => {
self.context.report_lint(
&INVALID_CONTEXT_MANAGER,
context_expression,
format_args!(
"Object of type `{}` cannot be used with `with` because it doesn't implement `__enter__`",
context_expression_ty.display(self.db())
),
);
Type::unknown()
}
(Symbol::Type(enter_ty, enter_boundness), exit) => {
if enter_boundness == Boundness::PossiblyUnbound {
self.context.report_lint(
&INVALID_CONTEXT_MANAGER,
context_expression,
format_args!(
"Object of type `{context_expression}` cannot be used with `with` because the method `__enter__` is possibly unbound",
context_expression = context_expression_ty.display(self.db()),
),
);
}
let target_ty = enter_ty
.try_call(self.db(), &CallArguments::positional([context_expression_ty]))
.map(|outcome| outcome.return_type(self.db()))
.unwrap_or_else(|err| {
// TODO: Use more specific error messages for the different error cases.
// E.g. hint toward the union variant that doesn't correctly implement enter,
// distinguish between a not callable `__enter__` attribute and a wrong signature.
self.context.report_lint(
&INVALID_CONTEXT_MANAGER,
context_expression,
format_args!("
Object of type `{context_expression}` cannot be used with `with` because it does not correctly implement `__enter__`",
context_expression = context_expression_ty.display(self.db()),
),
);
err.fallback_return_type(self.db())
});
match exit {
Symbol::Unbound => {
self.context.report_lint(
&INVALID_CONTEXT_MANAGER,
context_expression,
format_args!(
"Object of type `{}` cannot be used with `with` because it doesn't implement `__exit__`",
context_expression_ty.display(self.db())
),
);
}
Symbol::Type(exit_ty, exit_boundness) => {
// TODO: Use the `exit_ty` to determine if any raised exception is suppressed.
if exit_boundness == Boundness::PossiblyUnbound {
self.context.report_lint(
&INVALID_CONTEXT_MANAGER,
context_expression,
format_args!(
"Object of type `{context_expression}` cannot be used with `with` because the method `__exit__` is possibly unbound",
context_expression = context_expression_ty.display(self.db()),
),
);
}
if exit_ty
.try_call(
self.db(),
&CallArguments::positional([
context_manager_ty,
Type::none(self.db()),
Type::none(self.db()),
Type::none(self.db()),
]),
)
.is_err()
{
// TODO: Use more specific error messages for the different error cases.
// E.g. hint toward the union variant that doesn't correctly implement enter,
// distinguish between a not callable `__exit__` attribute and a wrong signature.
self.context.report_lint(
&INVALID_CONTEXT_MANAGER,
context_expression,
format_args!(
"Object of type `{context_expression}` cannot be used with `with` because it does not correctly implement `__exit__`",
context_expression = context_expression_ty.display(self.db()),
),
);
}
}
}
target_ty
}
}
context_expression_ty
.try_enter(self.db())
.unwrap_or_else(|err| {
err.report_diagnostic(&self.context, context_expression.into());
err.fallback_enter_type(self.db())
})
}
fn infer_except_handler_definition(

View file

@ -42,26 +42,28 @@ impl<'db> Unpacker<'db> {
"Unpacking target must be a list or tuple expression"
);
let mut value_ty = infer_expression_types(self.db(), value.expression())
let value_ty = infer_expression_types(self.db(), value.expression())
.expression_type(value.scoped_expression_id(self.db(), self.scope));
if value.is_assign()
&& self.context.in_stub()
&& value
.expression()
.node_ref(self.db())
.is_ellipsis_literal_expr()
{
value_ty = Type::unknown();
}
if value.is_iterable() {
// If the value is an iterable, then the type that needs to be unpacked is the iterator
// type.
value_ty = value_ty.try_iterate(self.db()).unwrap_or_else(|err| {
let value_ty = match value {
UnpackValue::Assign(expression) => {
if self.context.in_stub()
&& expression.node_ref(self.db()).is_ellipsis_literal_expr()
{
Type::unknown()
} else {
value_ty
}
}
UnpackValue::Iterable(_) => value_ty.try_iterate(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, value.as_any_node_ref(self.db()));
err.fallback_element_type(self.db())
});
}
}),
UnpackValue::ContextManager(_) => value_ty.try_enter(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, value.as_any_node_ref(self.db()));
err.fallback_enter_type(self.db())
}),
};
self.unpack_inner(target, value.as_any_node_ref(self.db()), value_ty);
}

View file

@ -63,25 +63,19 @@ impl<'db> Unpack<'db> {
pub(crate) enum UnpackValue<'db> {
/// An iterable expression like the one in a `for` loop or a comprehension.
Iterable(Expression<'db>),
/// An context manager expression like the one in a `with` statement.
ContextManager(Expression<'db>),
/// An expression that is being assigned to a target.
Assign(Expression<'db>),
}
impl<'db> UnpackValue<'db> {
/// Returns `true` if the value is an iterable expression.
pub(crate) const fn is_iterable(self) -> bool {
matches!(self, UnpackValue::Iterable(_))
}
/// Returns `true` if the value is being assigned to a target.
pub(crate) const fn is_assign(self) -> bool {
matches!(self, UnpackValue::Assign(_))
}
/// Returns the underlying [`Expression`] that is being unpacked.
pub(crate) const fn expression(self) -> Expression<'db> {
match self {
UnpackValue::Assign(expr) | UnpackValue::Iterable(expr) => expr,
UnpackValue::Assign(expr)
| UnpackValue::Iterable(expr)
| UnpackValue::ContextManager(expr) => expr,
}
}