mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-27 20:42:10 +00:00
[red-knot] Add symbols for for
loop variables (#13075)
## Summary This PR adds symbols introduced by `for` loops to red-knot: - `x` in `for x in range(10): pass` - `x` and `y` in `for x, y in d.items(): pass` - `a`, `b`, `c` and `d` in `for [((a,), b), (c, d)] in foo: pass` ## Test Plan Several tests added, and the assertion in the benchmarks has been updated. --------- Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
parent
99df859e20
commit
d19fd1b91c
5 changed files with 181 additions and 18 deletions
|
@ -1095,4 +1095,56 @@ match subject:
|
|||
vec!["subject", "a", "b", "c", "d", "f", "e", "h", "g", "Foo", "i", "j", "k", "l"]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn for_loops_single_assignment() {
|
||||
let TestCase { db, file } = test_case("for x in a: pass");
|
||||
let scope = global_scope(&db, file);
|
||||
let global_table = symbol_table(&db, scope);
|
||||
|
||||
assert_eq!(&names(&global_table), &["a", "x"]);
|
||||
|
||||
let use_def = use_def_map(&db, scope);
|
||||
let definition = use_def
|
||||
.first_public_definition(global_table.symbol_id_by_name("x").unwrap())
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(definition.node(&db), DefinitionKind::For(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn for_loops_simple_unpacking() {
|
||||
let TestCase { db, file } = test_case("for (x, y) in a: pass");
|
||||
let scope = global_scope(&db, file);
|
||||
let global_table = symbol_table(&db, scope);
|
||||
|
||||
assert_eq!(&names(&global_table), &["a", "x", "y"]);
|
||||
|
||||
let use_def = use_def_map(&db, scope);
|
||||
let x_definition = use_def
|
||||
.first_public_definition(global_table.symbol_id_by_name("x").unwrap())
|
||||
.unwrap();
|
||||
let y_definition = use_def
|
||||
.first_public_definition(global_table.symbol_id_by_name("y").unwrap())
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(x_definition.node(&db), DefinitionKind::For(_)));
|
||||
assert!(matches!(y_definition.node(&db), DefinitionKind::For(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn for_loops_complex_unpacking() {
|
||||
let TestCase { db, file } = test_case("for [((a,) b), (c, d)] in e: pass");
|
||||
let scope = global_scope(&db, file);
|
||||
let global_table = symbol_table(&db, scope);
|
||||
|
||||
assert_eq!(&names(&global_table), &["e", "a", "b", "c", "d"]);
|
||||
|
||||
let use_def = use_def_map(&db, scope);
|
||||
let definition = use_def
|
||||
.first_public_definition(global_table.symbol_id_by_name("a").unwrap())
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(definition.node(&db), DefinitionKind::For(_)));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
|
|||
use crate::semantic_index::ast_ids::AstIdsBuilder;
|
||||
use crate::semantic_index::definition::{
|
||||
AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef, Definition, DefinitionNodeKey,
|
||||
DefinitionNodeRef, ImportFromDefinitionNodeRef,
|
||||
DefinitionNodeRef, ForStmtDefinitionNodeRef, ImportFromDefinitionNodeRef,
|
||||
};
|
||||
use crate::semantic_index::expression::Expression;
|
||||
use crate::semantic_index::symbol::{
|
||||
|
@ -578,6 +578,27 @@ where
|
|||
ast::Stmt::Break(_) => {
|
||||
self.loop_break_states.push(self.flow_snapshot());
|
||||
}
|
||||
|
||||
ast::Stmt::For(
|
||||
for_stmt @ ast::StmtFor {
|
||||
range: _,
|
||||
is_async: _,
|
||||
target,
|
||||
iter,
|
||||
body,
|
||||
orelse,
|
||||
},
|
||||
) => {
|
||||
// TODO add control flow similar to `ast::Stmt::While` above
|
||||
self.add_standalone_expression(iter);
|
||||
self.visit_expr(iter);
|
||||
debug_assert!(self.current_assignment.is_none());
|
||||
self.current_assignment = Some(for_stmt.into());
|
||||
self.visit_expr(target);
|
||||
self.current_assignment = None;
|
||||
self.visit_body(body);
|
||||
self.visit_body(orelse);
|
||||
}
|
||||
_ => {
|
||||
walk_stmt(self, stmt);
|
||||
}
|
||||
|
@ -624,6 +645,15 @@ where
|
|||
Some(CurrentAssignment::AugAssign(aug_assign)) => {
|
||||
self.add_definition(symbol, aug_assign);
|
||||
}
|
||||
Some(CurrentAssignment::For(node)) => {
|
||||
self.add_definition(
|
||||
symbol,
|
||||
ForStmtDefinitionNodeRef {
|
||||
iterable: &node.iter,
|
||||
target: name_node,
|
||||
},
|
||||
);
|
||||
}
|
||||
Some(CurrentAssignment::Named(named)) => {
|
||||
// TODO(dhruvmanila): If the current scope is a comprehension, then the
|
||||
// named expression is implicitly nonlocal. This is yet to be
|
||||
|
@ -796,6 +826,7 @@ enum CurrentAssignment<'a> {
|
|||
Assign(&'a ast::StmtAssign),
|
||||
AnnAssign(&'a ast::StmtAnnAssign),
|
||||
AugAssign(&'a ast::StmtAugAssign),
|
||||
For(&'a ast::StmtFor),
|
||||
Named(&'a ast::ExprNamed),
|
||||
Comprehension {
|
||||
node: &'a ast::Comprehension,
|
||||
|
@ -822,6 +853,12 @@ impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a ast::StmtFor> for CurrentAssignment<'a> {
|
||||
fn from(value: &'a ast::StmtFor) -> Self {
|
||||
Self::For(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
|
||||
fn from(value: &'a ast::ExprNamed) -> Self {
|
||||
Self::Named(value)
|
||||
|
|
|
@ -39,6 +39,7 @@ impl<'db> Definition<'db> {
|
|||
pub(crate) enum DefinitionNodeRef<'a> {
|
||||
Import(&'a ast::Alias),
|
||||
ImportFrom(ImportFromDefinitionNodeRef<'a>),
|
||||
For(ForStmtDefinitionNodeRef<'a>),
|
||||
Function(&'a ast::StmtFunctionDef),
|
||||
Class(&'a ast::StmtClassDef),
|
||||
NamedExpression(&'a ast::ExprNamed),
|
||||
|
@ -92,6 +93,12 @@ impl<'a> From<ImportFromDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> From<ForStmtDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
|
||||
fn from(value: ForStmtDefinitionNodeRef<'a>) -> Self {
|
||||
Self::For(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<AssignmentDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
|
||||
fn from(node_ref: AssignmentDefinitionNodeRef<'a>) -> Self {
|
||||
Self::Assignment(node_ref)
|
||||
|
@ -134,6 +141,12 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> {
|
|||
pub(crate) target: &'a ast::ExprName,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub(crate) struct ForStmtDefinitionNodeRef<'a> {
|
||||
pub(crate) iterable: &'a ast::Expr,
|
||||
pub(crate) target: &'a ast::ExprName,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
|
||||
pub(crate) node: &'a ast::Comprehension,
|
||||
|
@ -174,6 +187,12 @@ impl DefinitionNodeRef<'_> {
|
|||
DefinitionNodeRef::AugmentedAssignment(augmented_assignment) => {
|
||||
DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment))
|
||||
}
|
||||
DefinitionNodeRef::For(ForStmtDefinitionNodeRef { iterable, target }) => {
|
||||
DefinitionKind::For(ForStmtDefinitionKind {
|
||||
iterable: AstNodeRef::new(parsed.clone(), iterable),
|
||||
target: AstNodeRef::new(parsed, target),
|
||||
})
|
||||
}
|
||||
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => {
|
||||
DefinitionKind::Comprehension(ComprehensionDefinitionKind {
|
||||
node: AstNodeRef::new(parsed, node),
|
||||
|
@ -212,6 +231,10 @@ impl DefinitionNodeRef<'_> {
|
|||
}) => target.into(),
|
||||
Self::AnnotatedAssignment(node) => node.into(),
|
||||
Self::AugmentedAssignment(node) => node.into(),
|
||||
Self::For(ForStmtDefinitionNodeRef {
|
||||
iterable: _,
|
||||
target,
|
||||
}) => target.into(),
|
||||
Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(),
|
||||
Self::Parameter(node) => match node {
|
||||
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
|
||||
|
@ -232,6 +255,7 @@ pub enum DefinitionKind {
|
|||
Assignment(AssignmentDefinitionKind),
|
||||
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
|
||||
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
|
||||
For(ForStmtDefinitionKind),
|
||||
Comprehension(ComprehensionDefinitionKind),
|
||||
Parameter(AstNodeRef<ast::Parameter>),
|
||||
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
|
||||
|
@ -302,6 +326,22 @@ impl WithItemDefinitionKind {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ForStmtDefinitionKind {
|
||||
iterable: AstNodeRef<ast::Expr>,
|
||||
target: AstNodeRef<ast::ExprName>,
|
||||
}
|
||||
|
||||
impl ForStmtDefinitionKind {
|
||||
pub(crate) fn iterable(&self) -> &ast::Expr {
|
||||
self.iterable.node()
|
||||
}
|
||||
|
||||
pub(crate) fn target(&self) -> &ast::ExprName {
|
||||
self.target.node()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
|
||||
pub(crate) struct DefinitionNodeKey(NodeKey);
|
||||
|
||||
|
@ -347,6 +387,12 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<&ast::StmtFor> for DefinitionNodeKey {
|
||||
fn from(value: &ast::StmtFor) -> Self {
|
||||
Self(NodeKey::from_node(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ast::Comprehension> for DefinitionNodeKey {
|
||||
fn from(node: &ast::Comprehension) -> Self {
|
||||
Self(NodeKey::from_node(node))
|
||||
|
|
|
@ -138,7 +138,6 @@ pub(crate) struct TypeInference<'db> {
|
|||
}
|
||||
|
||||
impl<'db> TypeInference<'db> {
|
||||
#[allow(unused)]
|
||||
pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> {
|
||||
self.expressions[&expression]
|
||||
}
|
||||
|
@ -317,6 +316,13 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
DefinitionKind::AugmentedAssignment(augmented_assignment) => {
|
||||
self.infer_augment_assignment_definition(augmented_assignment.node(), definition);
|
||||
}
|
||||
DefinitionKind::For(for_statement_definition) => {
|
||||
self.infer_for_statement_definition(
|
||||
for_statement_definition.target(),
|
||||
for_statement_definition.iterable(),
|
||||
definition,
|
||||
);
|
||||
}
|
||||
DefinitionKind::NamedExpression(named_expression) => {
|
||||
self.infer_named_expression_definition(named_expression.node(), definition);
|
||||
}
|
||||
|
@ -865,11 +871,48 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
} = for_statement;
|
||||
|
||||
self.infer_expression(iter);
|
||||
self.infer_expression(target);
|
||||
// TODO more complex assignment targets
|
||||
if let ast::Expr::Name(name) = &**target {
|
||||
self.infer_definition(name);
|
||||
} else {
|
||||
self.infer_expression(target);
|
||||
}
|
||||
self.infer_body(body);
|
||||
self.infer_body(orelse);
|
||||
}
|
||||
|
||||
fn infer_for_statement_definition(
|
||||
&mut self,
|
||||
target: &ast::ExprName,
|
||||
iterable: &ast::Expr,
|
||||
definition: Definition<'db>,
|
||||
) {
|
||||
let expression = self.index.expression(iterable);
|
||||
let result = infer_expression_types(self.db, expression);
|
||||
self.extend(result);
|
||||
let iterable_ty = self
|
||||
.types
|
||||
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
|
||||
|
||||
// TODO(Alex): only a valid iterable if the *type* of `iterable_ty` has an `__iter__`
|
||||
// member (dunders are never looked up on an instance)
|
||||
let _dunder_iter_ty = iterable_ty.member(self.db, &ast::name::Name::from("__iter__"));
|
||||
|
||||
// TODO(Alex):
|
||||
// - infer the return type of the `__iter__` method, which gives us the iterator
|
||||
// - lookup the `__next__` method on the iterator
|
||||
// - infer the return type of the iterator's `__next__` method,
|
||||
// which gives us the type of the variable being bound here
|
||||
// (...or the type of the object being unpacked into multiple definitions, if it's something like
|
||||
// `for k, v in d.items(): ...`)
|
||||
let loop_var_value_ty = Type::Unknown;
|
||||
|
||||
self.types
|
||||
.expressions
|
||||
.insert(target.scoped_ast_id(self.db, self.scope), loop_var_value_ty);
|
||||
self.types.definitions.insert(definition, loop_var_value_ty);
|
||||
}
|
||||
|
||||
fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) {
|
||||
let ast::StmtWhile {
|
||||
range: _,
|
||||
|
|
|
@ -40,23 +40,8 @@ static EXPECTED_DIAGNOSTICS: &[&str] = &[
|
|||
"Use double quotes for strings",
|
||||
"Use double quotes for strings",
|
||||
"Use double quotes for strings",
|
||||
"/src/tomllib/_parser.py:153:22: Name 'key' used when not defined.",
|
||||
"/src/tomllib/_parser.py:153:27: Name 'flag' used when not defined.",
|
||||
"/src/tomllib/_parser.py:159:16: Name 'k' used when not defined.",
|
||||
"/src/tomllib/_parser.py:161:25: Name 'k' used when not defined.",
|
||||
"/src/tomllib/_parser.py:168:16: Name 'k' used when not defined.",
|
||||
"/src/tomllib/_parser.py:169:22: Name 'k' used when not defined.",
|
||||
"/src/tomllib/_parser.py:170:25: Name 'k' used when not defined.",
|
||||
"/src/tomllib/_parser.py:180:16: Name 'k' used when not defined.",
|
||||
"/src/tomllib/_parser.py:182:31: Name 'k' used when not defined.",
|
||||
"/src/tomllib/_parser.py:206:16: Name 'k' used when not defined.",
|
||||
"/src/tomllib/_parser.py:207:22: Name 'k' used when not defined.",
|
||||
"/src/tomllib/_parser.py:208:25: Name 'k' used when not defined.",
|
||||
"/src/tomllib/_parser.py:330:32: Name 'header' used when not defined.",
|
||||
"/src/tomllib/_parser.py:330:41: Name 'key' used when not defined.",
|
||||
"/src/tomllib/_parser.py:333:26: Name 'cont_key' used when not defined.",
|
||||
"/src/tomllib/_parser.py:334:71: Name 'cont_key' used when not defined.",
|
||||
"/src/tomllib/_parser.py:337:31: Name 'cont_key' used when not defined.",
|
||||
"/src/tomllib/_parser.py:628:75: Name 'e' used when not defined.",
|
||||
"/src/tomllib/_parser.py:686:23: Name 'parse_float' used when not defined.",
|
||||
];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue