[red-knot] add cycle-free while-loop control flow (#12413)

Add support for while-loop control flow.

This doesn't yet include general support for terminals and reachability;
that is wider than just while loops and belongs in its own PR.

This also doesn't yet add support for cyclic definitions in loops; that
comes with enough of its own complexity in Salsa that I want to handle
it separately.
This commit is contained in:
Carl Meyer 2024-07-22 14:27:33 -07:00 committed by GitHub
parent dbbe3526ef
commit c7b13bb8fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 117 additions and 8 deletions

View file

@ -31,8 +31,10 @@ pub(super) struct SemanticIndexBuilder<'db> {
file: File, file: File,
module: &'db ParsedModule, module: &'db ParsedModule,
scope_stack: Vec<FileScopeId>, scope_stack: Vec<FileScopeId>,
/// the assignment we're currently visiting /// The assignment we're currently visiting.
current_assignment: Option<CurrentAssignment<'db>>, current_assignment: Option<CurrentAssignment<'db>>,
/// Flow states at each `break` in the current loop.
loop_break_states: Vec<FlowSnapshot>,
// Semantic Index fields // Semantic Index fields
scopes: IndexVec<FileScopeId, Scope>, scopes: IndexVec<FileScopeId, Scope>,
@ -54,6 +56,7 @@ impl<'db> SemanticIndexBuilder<'db> {
module: parsed, module: parsed,
scope_stack: Vec::new(), scope_stack: Vec::new(),
current_assignment: None, current_assignment: None,
loop_break_states: vec![],
scopes: IndexVec::new(), scopes: IndexVec::new(),
symbol_tables: IndexVec::new(), symbol_tables: IndexVec::new(),
@ -125,33 +128,38 @@ impl<'db> SemanticIndexBuilder<'db> {
&mut self.symbol_tables[scope_id] &mut self.symbol_tables[scope_id]
} }
fn current_use_def_map(&mut self) -> &mut UseDefMapBuilder<'db> { fn current_use_def_map_mut(&mut self) -> &mut UseDefMapBuilder<'db> {
let scope_id = self.current_scope(); let scope_id = self.current_scope();
&mut self.use_def_maps[scope_id] &mut self.use_def_maps[scope_id]
} }
fn current_use_def_map(&self) -> &UseDefMapBuilder<'db> {
let scope_id = self.current_scope();
&self.use_def_maps[scope_id]
}
fn current_ast_ids(&mut self) -> &mut AstIdsBuilder { fn current_ast_ids(&mut self) -> &mut AstIdsBuilder {
let scope_id = self.current_scope(); let scope_id = self.current_scope();
&mut self.ast_ids[scope_id] &mut self.ast_ids[scope_id]
} }
fn flow_snapshot(&mut self) -> FlowSnapshot { fn flow_snapshot(&self) -> FlowSnapshot {
self.current_use_def_map().snapshot() self.current_use_def_map().snapshot()
} }
fn flow_restore(&mut self, state: FlowSnapshot) { fn flow_restore(&mut self, state: FlowSnapshot) {
self.current_use_def_map().restore(state); self.current_use_def_map_mut().restore(state);
} }
fn flow_merge(&mut self, state: &FlowSnapshot) { fn flow_merge(&mut self, state: &FlowSnapshot) {
self.current_use_def_map().merge(state); self.current_use_def_map_mut().merge(state);
} }
fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId { fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId {
let symbol_table = self.current_symbol_table(); let symbol_table = self.current_symbol_table();
let (symbol_id, added) = symbol_table.add_or_update_symbol(name, flags); let (symbol_id, added) = symbol_table.add_or_update_symbol(name, flags);
if added { if added {
let use_def_map = self.current_use_def_map(); let use_def_map = self.current_use_def_map_mut();
use_def_map.add_symbol(symbol_id); use_def_map.add_symbol(symbol_id);
} }
symbol_id symbol_id
@ -176,7 +184,7 @@ impl<'db> SemanticIndexBuilder<'db> {
self.definitions_by_node self.definitions_by_node
.insert(definition_node.key(), definition); .insert(definition_node.key(), definition);
self.current_use_def_map() self.current_use_def_map_mut()
.record_definition(symbol, definition); .record_definition(symbol, definition);
definition definition
@ -416,6 +424,33 @@ where
self.flow_merge(&pre_if); self.flow_merge(&pre_if);
} }
} }
ast::Stmt::While(node) => {
self.visit_expr(&node.test);
let pre_loop = self.flow_snapshot();
// Save aside any break states from an outer loop
let saved_break_states = std::mem::take(&mut self.loop_break_states);
self.visit_body(&node.body);
// Get the break states from the body of this loop, and restore the saved outer
// ones.
let break_states =
std::mem::replace(&mut self.loop_break_states, saved_break_states);
// We may execute the `else` clause without ever executing the body, so merge in
// the pre-loop state before visiting `else`.
self.flow_merge(&pre_loop);
self.visit_body(&node.orelse);
// Breaking out of a while loop bypasses the `else` clause, so merge in the break
// states after visiting `else`.
for break_state in break_states {
self.flow_merge(&break_state);
}
}
ast::Stmt::Break(_) => {
self.loop_break_states.push(self.flow_snapshot());
}
_ => { _ => {
walk_stmt(self, stmt); walk_stmt(self, stmt);
} }
@ -460,7 +495,7 @@ where
if flags.contains(SymbolFlags::IS_USED) { if flags.contains(SymbolFlags::IS_USED) {
let use_id = self.current_ast_ids().record_use(expr); let use_id = self.current_ast_ids().record_use(expr);
self.current_use_def_map().record_use(symbol, use_id); self.current_use_def_map_mut().record_use(symbol, use_id);
} }
walk_expr(self, expr); walk_expr(self, expr);

View file

@ -194,6 +194,7 @@ pub(super) struct FlowSnapshot {
definitions_by_symbol: IndexVec<ScopedSymbolId, Definitions>, definitions_by_symbol: IndexVec<ScopedSymbolId, Definitions>,
} }
#[derive(Debug)]
pub(super) struct UseDefMapBuilder<'db> { pub(super) struct UseDefMapBuilder<'db> {
/// Definition IDs array for `definitions_by_use` and `definitions_by_symbol` to slice into. /// Definition IDs array for `definitions_by_use` and `definitions_by_symbol` to slice into.
all_definitions: Vec<Definition<'db>>, all_definitions: Vec<Definition<'db>>,

View file

@ -1481,6 +1481,79 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn while_loop() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
x = 1
while flag:
x = 2
",
)?;
// body of while loop may or may not run
assert_public_ty(&db, "/src/a.py", "x", "Literal[1, 2]");
Ok(())
}
#[test]
fn while_else_no_break() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
x = 1
while flag:
x = 2
else:
y = x
x = 3
",
)?;
// body of the loop can't break, so we can get else, or body+else
// x must be 3, because else will always run
assert_public_ty(&db, "/src/a.py", "x", "Literal[3]");
// y can be 1 or 2 because else always runs, and body may or may not run first
assert_public_ty(&db, "/src/a.py", "y", "Literal[1, 2]");
Ok(())
}
#[test]
fn while_else_may_break() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
x = 1
y = 0
while flag:
x = 2
if flag2:
y = 4
break
else:
y = x
x = 3
",
)?;
// body may break: we can get just-body (only if we break), just-else, or body+else
assert_public_ty(&db, "/src/a.py", "x", "Literal[2, 3]");
// if just-body were possible without the break, then 0 would be possible for y
// 1 and 2 both being possible for y shows that we can hit else with or without body
assert_public_ty(&db, "/src/a.py", "y", "Literal[1, 2, 4]");
Ok(())
}
fn first_public_def<'db>(db: &'db TestDb, file: File, name: &str) -> Definition<'db> { fn first_public_def<'db>(db: &'db TestDb, file: File, name: &str) -> Definition<'db> {
let scope = global_scope(db, file); let scope = global_scope(db, file);
*use_def_map(db, scope) *use_def_map(db, scope)