[red-knot] Add type inference for loop variables inside comprehension scopes (#13251)

This commit is contained in:
Alex Waygood 2024-09-09 16:22:01 -04:00 committed by GitHub
parent ac720cd705
commit 6f53aaf931
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 371 additions and 15 deletions

View file

@ -689,6 +689,7 @@ where
iterable: &node.iter,
target: name_node,
first,
is_async: node.is_async,
},
);
}

View file

@ -167,6 +167,7 @@ pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::ExprName,
pub(crate) first: bool,
pub(crate) is_async: bool,
}
#[derive(Copy, Clone, Debug)]
@ -227,10 +228,12 @@ impl DefinitionNodeRef<'_> {
iterable,
target,
first,
is_async,
}) => DefinitionKind::Comprehension(ComprehensionDefinitionKind {
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
first,
is_async,
}),
DefinitionNodeRef::Parameter(parameter) => match parameter {
ast::AnyParameterRef::Variadic(parameter) => {
@ -337,6 +340,7 @@ pub struct ComprehensionDefinitionKind {
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
first: bool,
is_async: bool,
}
impl ComprehensionDefinitionKind {
@ -351,6 +355,10 @@ impl ComprehensionDefinitionKind {
pub(crate) fn is_first(&self) -> bool {
self.first
}
pub(crate) fn is_async(&self) -> bool {
self.is_async
}
}
#[derive(Clone, Debug)]

View file

@ -406,6 +406,7 @@ impl<'db> TypeInferenceBuilder<'db> {
comprehension.iterable(),
comprehension.target(),
comprehension.is_first(),
comprehension.is_async(),
definition,
);
}
@ -1444,7 +1445,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let expr_id = expression.scoped_ast_id(self.db, self.scope);
let previous = self.types.expressions.insert(expr_id, ty);
assert!(previous.is_none());
assert_eq!(previous, None);
ty
}
@ -1747,22 +1748,38 @@ impl<'db> TypeInferenceBuilder<'db> {
iterable: &ast::Expr,
target: &ast::ExprName,
is_first: bool,
is_async: bool,
definition: Definition<'db>,
) {
if !is_first {
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db, expression);
self.extend(result);
let _iterable_ty = self
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
}
// TODO(dhruvmanila): The iter type for the first comprehension is coming from the
// enclosing scope.
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db, expression);
// TODO(dhruvmanila): The target type should be inferred based on the iter type instead,
// similar to how it's done in `infer_for_statement_definition`.
let target_ty = Type::Unknown;
// Two things are different if it's the first comprehension:
// (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope,
// because that's the scope we visit it in in the semantic index builder
// (2) We must *not* call `self.extend()` on the result of the type inference,
// because `ScopedExpressionId`s are only meaningful within their own scope, so
// we'd add types for random wrong expressions in the current scope
let iterable_ty = if is_first {
let lookup_scope = self
.index
.parent_scope_id(self.scope.file_scope_id(self.db))
.expect("A comprehension should never be the top-level scope")
.to_scope_id(self.db, self.file);
result.expression_ty(iterable.scoped_ast_id(self.db, lookup_scope))
} else {
self.extend(result);
result.expression_ty(iterable.scoped_ast_id(self.db, self.scope))
};
let target_ty = if is_async {
// TODO: async iterables/iterators! -- Alex
Type::Unknown
} else {
iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(iterable.into(), self)
};
self.types
.expressions
@ -4191,7 +4208,6 @@ mod tests {
",
)?;
// TODO(Alex) async iterables/iterators!
assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown");
Ok(())
@ -4326,6 +4342,337 @@ mod tests {
Ok(())
}
#[test]
fn basic_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
def foo():
[x for y in IterableOfIterables() for x in y]
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
class IteratorOfIterables:
def __next__(self) -> IntIterable:
return IntIterable()
class IterableOfIterables:
def __iter__(self) -> IteratorOfIterables:
return IteratorOfIterables()
",
)?;
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "int");
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "y", "IntIterable");
Ok(())
}
#[test]
fn comprehension_inside_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
def foo():
[[x for x in iter1] for y in iter2]
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
iter1 = IntIterable()
iter2 = IntIterable()
",
)?;
assert_scope_ty(
&db,
"src/a.py",
&["foo", "<listcomp>", "<listcomp>"],
"x",
"int",
);
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "y", "int");
Ok(())
}
#[test]
fn inner_comprehension_referencing_outer_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
def foo():
[[x for x in y] for y in z]
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
class IteratorOfIterables:
def __next__(self) -> IntIterable:
return IntIterable()
class IterableOfIterables:
def __iter__(self) -> IteratorOfIterables:
return IteratorOfIterables()
z = IterableOfIterables()
",
)?;
assert_scope_ty(
&db,
"src/a.py",
&["foo", "<listcomp>", "<listcomp>"],
"x",
"int",
);
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "y", "IntIterable");
Ok(())
}
#[test]
fn comprehension_with_unbound_iter() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented("src/a.py", "[z for z in x]")?;
assert_scope_ty(&db, "src/a.py", &["<listcomp>"], "x", "Unbound");
// Iterating over an `Unbound` yields `Unknown`:
assert_scope_ty(&db, "src/a.py", &["<listcomp>"], "z", "Unknown");
// TODO: not the greatest error message in the world! --Alex
assert_file_diagnostics(
&db,
"src/a.py",
&["Object of type 'Unbound' is not iterable"],
);
Ok(())
}
#[test]
fn comprehension_with_not_iterable_iter_in_second_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
def foo():
[z for x in IntIterable() for z in x]
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
",
)?;
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "int");
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "z", "Unknown");
assert_file_diagnostics(&db, "src/a.py", &["Object of type 'int' is not iterable"]);
Ok(())
}
#[test]
fn dict_comprehension_variable_key() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
def foo():
{x: 0 for x in IntIterable()}
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
",
)?;
assert_scope_ty(&db, "src/a.py", &["foo", "<dictcomp>"], "x", "int");
assert_file_diagnostics(&db, "src/a.py", &[]);
Ok(())
}
#[test]
fn dict_comprehension_variable_value() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
def foo():
{0: x for x in IntIterable()}
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
",
)?;
assert_scope_ty(&db, "src/a.py", &["foo", "<dictcomp>"], "x", "int");
assert_file_diagnostics(&db, "src/a.py", &[]);
Ok(())
}
#[test]
fn comprehension_with_missing_in_keyword() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
def foo():
[z for z IntIterable()]
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
",
)?;
// We'll emit a diagnostic separately for invalid syntax,
// but it's reasonably clear here what they *meant* to write,
// so we'll still infer the correct type:
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "z", "int");
Ok(())
}
#[test]
fn comprehension_with_missing_iter() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
def foo():
[z for in IntIterable()]
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
",
)?;
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "z", "Unbound");
// (There is a diagnostic for invalid syntax that's emitted, but it's not listed by `assert_file_diagnostics`)
assert_file_diagnostics(&db, "src/a.py", &[]);
Ok(())
}
#[test]
fn comprehension_with_missing_for() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented("src/a.py", "[z for z in]")?;
assert_scope_ty(&db, "src/a.py", &["<listcomp>"], "z", "Unknown");
Ok(())
}
#[test]
fn comprehension_with_missing_in_keyword_and_missing_iter() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented("src/a.py", "[z for z]")?;
assert_scope_ty(&db, "src/a.py", &["<listcomp>"], "z", "Unknown");
Ok(())
}
/// This tests that we understand that `async` comprehensions
/// do not work according to the synchronous iteration protocol
#[test]
fn invalid_async_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
async def foo():
[x async for x in Iterable()]
class Iterator:
def __next__(self) -> int:
return 42
class Iterable:
def __iter__(self) -> Iterator:
return Iterator()
",
)?;
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "Unknown");
Ok(())
}
#[test]
fn basic_async_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
async def foo():
[x async for x in AsyncIterable()]
class AsyncIterator:
async def __anext__(self) -> int:
return 42
class AsyncIterable:
def __aiter__(self) -> AsyncIterator:
return AsyncIterator()
",
)?;
// TODO async iterables/iterators! --Alex
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "Unknown");
Ok(())
}
#[test]
fn invalid_iterable() {
let mut db = setup_db();