mirror of
https://github.com/astral-sh/ruff.git
synced 2025-10-02 22:55:08 +00:00
[red-knot] Add type inference for loop variables inside comprehension scopes (#13251)
This commit is contained in:
parent
ac720cd705
commit
6f53aaf931
3 changed files with 371 additions and 15 deletions
|
@ -689,6 +689,7 @@ where
|
|||
iterable: &node.iter,
|
||||
target: name_node,
|
||||
first,
|
||||
is_async: node.is_async,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
|
|
@ -167,6 +167,7 @@ pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
|
|||
pub(crate) iterable: &'a ast::Expr,
|
||||
pub(crate) target: &'a ast::ExprName,
|
||||
pub(crate) first: bool,
|
||||
pub(crate) is_async: bool,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
|
@ -227,10 +228,12 @@ impl DefinitionNodeRef<'_> {
|
|||
iterable,
|
||||
target,
|
||||
first,
|
||||
is_async,
|
||||
}) => DefinitionKind::Comprehension(ComprehensionDefinitionKind {
|
||||
iterable: AstNodeRef::new(parsed.clone(), iterable),
|
||||
target: AstNodeRef::new(parsed, target),
|
||||
first,
|
||||
is_async,
|
||||
}),
|
||||
DefinitionNodeRef::Parameter(parameter) => match parameter {
|
||||
ast::AnyParameterRef::Variadic(parameter) => {
|
||||
|
@ -337,6 +340,7 @@ pub struct ComprehensionDefinitionKind {
|
|||
iterable: AstNodeRef<ast::Expr>,
|
||||
target: AstNodeRef<ast::ExprName>,
|
||||
first: bool,
|
||||
is_async: bool,
|
||||
}
|
||||
|
||||
impl ComprehensionDefinitionKind {
|
||||
|
@ -351,6 +355,10 @@ impl ComprehensionDefinitionKind {
|
|||
pub(crate) fn is_first(&self) -> bool {
|
||||
self.first
|
||||
}
|
||||
|
||||
pub(crate) fn is_async(&self) -> bool {
|
||||
self.is_async
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
|
|
@ -406,6 +406,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
comprehension.iterable(),
|
||||
comprehension.target(),
|
||||
comprehension.is_first(),
|
||||
comprehension.is_async(),
|
||||
definition,
|
||||
);
|
||||
}
|
||||
|
@ -1444,7 +1445,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
|
||||
let expr_id = expression.scoped_ast_id(self.db, self.scope);
|
||||
let previous = self.types.expressions.insert(expr_id, ty);
|
||||
assert!(previous.is_none());
|
||||
assert_eq!(previous, None);
|
||||
|
||||
ty
|
||||
}
|
||||
|
@ -1747,22 +1748,38 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
iterable: &ast::Expr,
|
||||
target: &ast::ExprName,
|
||||
is_first: bool,
|
||||
is_async: bool,
|
||||
definition: Definition<'db>,
|
||||
) {
|
||||
if !is_first {
|
||||
let expression = self.index.expression(iterable);
|
||||
let result = infer_expression_types(self.db, expression);
|
||||
self.extend(result);
|
||||
let _iterable_ty = self
|
||||
.types
|
||||
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
|
||||
}
|
||||
// TODO(dhruvmanila): The iter type for the first comprehension is coming from the
|
||||
// enclosing scope.
|
||||
let expression = self.index.expression(iterable);
|
||||
let result = infer_expression_types(self.db, expression);
|
||||
|
||||
// TODO(dhruvmanila): The target type should be inferred based on the iter type instead,
|
||||
// similar to how it's done in `infer_for_statement_definition`.
|
||||
let target_ty = Type::Unknown;
|
||||
// Two things are different if it's the first comprehension:
|
||||
// (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope,
|
||||
// because that's the scope we visit it in in the semantic index builder
|
||||
// (2) We must *not* call `self.extend()` on the result of the type inference,
|
||||
// because `ScopedExpressionId`s are only meaningful within their own scope, so
|
||||
// we'd add types for random wrong expressions in the current scope
|
||||
let iterable_ty = if is_first {
|
||||
let lookup_scope = self
|
||||
.index
|
||||
.parent_scope_id(self.scope.file_scope_id(self.db))
|
||||
.expect("A comprehension should never be the top-level scope")
|
||||
.to_scope_id(self.db, self.file);
|
||||
result.expression_ty(iterable.scoped_ast_id(self.db, lookup_scope))
|
||||
} else {
|
||||
self.extend(result);
|
||||
result.expression_ty(iterable.scoped_ast_id(self.db, self.scope))
|
||||
};
|
||||
|
||||
let target_ty = if is_async {
|
||||
// TODO: async iterables/iterators! -- Alex
|
||||
Type::Unknown
|
||||
} else {
|
||||
iterable_ty
|
||||
.iterate(self.db)
|
||||
.unwrap_with_diagnostic(iterable.into(), self)
|
||||
};
|
||||
|
||||
self.types
|
||||
.expressions
|
||||
|
@ -4191,7 +4208,6 @@ mod tests {
|
|||
",
|
||||
)?;
|
||||
|
||||
// TODO(Alex) async iterables/iterators!
|
||||
assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown");
|
||||
|
||||
Ok(())
|
||||
|
@ -4326,6 +4342,337 @@ mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn basic_comprehension() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
"
|
||||
def foo():
|
||||
[x for y in IterableOfIterables() for x in y]
|
||||
|
||||
class IntIterator:
|
||||
def __next__(self) -> int:
|
||||
return 42
|
||||
|
||||
class IntIterable:
|
||||
def __iter__(self) -> IntIterator:
|
||||
return IntIterator()
|
||||
|
||||
class IteratorOfIterables:
|
||||
def __next__(self) -> IntIterable:
|
||||
return IntIterable()
|
||||
|
||||
class IterableOfIterables:
|
||||
def __iter__(self) -> IteratorOfIterables:
|
||||
return IteratorOfIterables()
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "int");
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "y", "IntIterable");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn comprehension_inside_comprehension() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
"
|
||||
def foo():
|
||||
[[x for x in iter1] for y in iter2]
|
||||
|
||||
class IntIterator:
|
||||
def __next__(self) -> int:
|
||||
return 42
|
||||
|
||||
class IntIterable:
|
||||
def __iter__(self) -> IntIterator:
|
||||
return IntIterator()
|
||||
|
||||
iter1 = IntIterable()
|
||||
iter2 = IntIterable()
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_scope_ty(
|
||||
&db,
|
||||
"src/a.py",
|
||||
&["foo", "<listcomp>", "<listcomp>"],
|
||||
"x",
|
||||
"int",
|
||||
);
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "y", "int");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inner_comprehension_referencing_outer_comprehension() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
"
|
||||
def foo():
|
||||
[[x for x in y] for y in z]
|
||||
|
||||
class IntIterator:
|
||||
def __next__(self) -> int:
|
||||
return 42
|
||||
|
||||
class IntIterable:
|
||||
def __iter__(self) -> IntIterator:
|
||||
return IntIterator()
|
||||
|
||||
class IteratorOfIterables:
|
||||
def __next__(self) -> IntIterable:
|
||||
return IntIterable()
|
||||
|
||||
class IterableOfIterables:
|
||||
def __iter__(self) -> IteratorOfIterables:
|
||||
return IteratorOfIterables()
|
||||
|
||||
z = IterableOfIterables()
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_scope_ty(
|
||||
&db,
|
||||
"src/a.py",
|
||||
&["foo", "<listcomp>", "<listcomp>"],
|
||||
"x",
|
||||
"int",
|
||||
);
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "y", "IntIterable");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn comprehension_with_unbound_iter() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented("src/a.py", "[z for z in x]")?;
|
||||
|
||||
assert_scope_ty(&db, "src/a.py", &["<listcomp>"], "x", "Unbound");
|
||||
|
||||
// Iterating over an `Unbound` yields `Unknown`:
|
||||
assert_scope_ty(&db, "src/a.py", &["<listcomp>"], "z", "Unknown");
|
||||
|
||||
// TODO: not the greatest error message in the world! --Alex
|
||||
assert_file_diagnostics(
|
||||
&db,
|
||||
"src/a.py",
|
||||
&["Object of type 'Unbound' is not iterable"],
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn comprehension_with_not_iterable_iter_in_second_comprehension() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
"
|
||||
def foo():
|
||||
[z for x in IntIterable() for z in x]
|
||||
|
||||
class IntIterator:
|
||||
def __next__(self) -> int:
|
||||
return 42
|
||||
|
||||
class IntIterable:
|
||||
def __iter__(self) -> IntIterator:
|
||||
return IntIterator()
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "int");
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "z", "Unknown");
|
||||
assert_file_diagnostics(&db, "src/a.py", &["Object of type 'int' is not iterable"]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dict_comprehension_variable_key() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
"
|
||||
def foo():
|
||||
{x: 0 for x in IntIterable()}
|
||||
|
||||
class IntIterator:
|
||||
def __next__(self) -> int:
|
||||
return 42
|
||||
|
||||
class IntIterable:
|
||||
def __iter__(self) -> IntIterator:
|
||||
return IntIterator()
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<dictcomp>"], "x", "int");
|
||||
assert_file_diagnostics(&db, "src/a.py", &[]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dict_comprehension_variable_value() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
"
|
||||
def foo():
|
||||
{0: x for x in IntIterable()}
|
||||
|
||||
class IntIterator:
|
||||
def __next__(self) -> int:
|
||||
return 42
|
||||
|
||||
class IntIterable:
|
||||
def __iter__(self) -> IntIterator:
|
||||
return IntIterator()
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<dictcomp>"], "x", "int");
|
||||
assert_file_diagnostics(&db, "src/a.py", &[]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn comprehension_with_missing_in_keyword() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
"
|
||||
def foo():
|
||||
[z for z IntIterable()]
|
||||
|
||||
class IntIterator:
|
||||
def __next__(self) -> int:
|
||||
return 42
|
||||
|
||||
class IntIterable:
|
||||
def __iter__(self) -> IntIterator:
|
||||
return IntIterator()
|
||||
",
|
||||
)?;
|
||||
|
||||
// We'll emit a diagnostic separately for invalid syntax,
|
||||
// but it's reasonably clear here what they *meant* to write,
|
||||
// so we'll still infer the correct type:
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "z", "int");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn comprehension_with_missing_iter() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
"
|
||||
def foo():
|
||||
[z for in IntIterable()]
|
||||
|
||||
class IntIterator:
|
||||
def __next__(self) -> int:
|
||||
return 42
|
||||
|
||||
class IntIterable:
|
||||
def __iter__(self) -> IntIterator:
|
||||
return IntIterator()
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "z", "Unbound");
|
||||
|
||||
// (There is a diagnostic for invalid syntax that's emitted, but it's not listed by `assert_file_diagnostics`)
|
||||
assert_file_diagnostics(&db, "src/a.py", &[]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn comprehension_with_missing_for() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
db.write_dedented("src/a.py", "[z for z in]")?;
|
||||
assert_scope_ty(&db, "src/a.py", &["<listcomp>"], "z", "Unknown");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn comprehension_with_missing_in_keyword_and_missing_iter() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
db.write_dedented("src/a.py", "[z for z]")?;
|
||||
assert_scope_ty(&db, "src/a.py", &["<listcomp>"], "z", "Unknown");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This tests that we understand that `async` comprehensions
|
||||
/// do not work according to the synchronous iteration protocol
|
||||
#[test]
|
||||
fn invalid_async_comprehension() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
"
|
||||
async def foo():
|
||||
[x async for x in Iterable()]
|
||||
class Iterator:
|
||||
def __next__(self) -> int:
|
||||
return 42
|
||||
class Iterable:
|
||||
def __iter__(self) -> Iterator:
|
||||
return Iterator()
|
||||
",
|
||||
)?;
|
||||
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "Unknown");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn basic_async_comprehension() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
"
|
||||
async def foo():
|
||||
[x async for x in AsyncIterable()]
|
||||
class AsyncIterator:
|
||||
async def __anext__(self) -> int:
|
||||
return 42
|
||||
class AsyncIterable:
|
||||
def __aiter__(self) -> AsyncIterator:
|
||||
return AsyncIterator()
|
||||
",
|
||||
)?;
|
||||
|
||||
// TODO async iterables/iterators! --Alex
|
||||
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "Unknown");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_iterable() {
|
||||
let mut db = setup_db();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue