diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 327893821d..ea9b518b5d 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -31,8 +31,10 @@ pub(super) struct SemanticIndexBuilder<'db> { file: File, module: &'db ParsedModule, scope_stack: Vec, - /// the assignment we're currently visiting + /// The assignment we're currently visiting. current_assignment: Option>, + /// Flow states at each `break` in the current loop. + loop_break_states: Vec, // Semantic Index fields scopes: IndexVec, @@ -54,6 +56,7 @@ impl<'db> SemanticIndexBuilder<'db> { module: parsed, scope_stack: Vec::new(), current_assignment: None, + loop_break_states: vec![], scopes: IndexVec::new(), symbol_tables: IndexVec::new(), @@ -125,33 +128,38 @@ impl<'db> SemanticIndexBuilder<'db> { &mut self.symbol_tables[scope_id] } - fn current_use_def_map(&mut self) -> &mut UseDefMapBuilder<'db> { + fn current_use_def_map_mut(&mut self) -> &mut UseDefMapBuilder<'db> { let scope_id = self.current_scope(); &mut self.use_def_maps[scope_id] } + fn current_use_def_map(&self) -> &UseDefMapBuilder<'db> { + let scope_id = self.current_scope(); + &self.use_def_maps[scope_id] + } + fn current_ast_ids(&mut self) -> &mut AstIdsBuilder { let scope_id = self.current_scope(); &mut self.ast_ids[scope_id] } - fn flow_snapshot(&mut self) -> FlowSnapshot { + fn flow_snapshot(&self) -> FlowSnapshot { self.current_use_def_map().snapshot() } fn flow_restore(&mut self, state: FlowSnapshot) { - self.current_use_def_map().restore(state); + self.current_use_def_map_mut().restore(state); } fn flow_merge(&mut self, state: &FlowSnapshot) { - self.current_use_def_map().merge(state); + self.current_use_def_map_mut().merge(state); } fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId { let symbol_table = self.current_symbol_table(); let (symbol_id, added) = symbol_table.add_or_update_symbol(name, flags); if added { - let use_def_map = self.current_use_def_map(); + let use_def_map = self.current_use_def_map_mut(); use_def_map.add_symbol(symbol_id); } symbol_id @@ -176,7 +184,7 @@ impl<'db> SemanticIndexBuilder<'db> { self.definitions_by_node .insert(definition_node.key(), definition); - self.current_use_def_map() + self.current_use_def_map_mut() .record_definition(symbol, definition); definition @@ -416,6 +424,33 @@ where self.flow_merge(&pre_if); } } + ast::Stmt::While(node) => { + self.visit_expr(&node.test); + + let pre_loop = self.flow_snapshot(); + + // Save aside any break states from an outer loop + let saved_break_states = std::mem::take(&mut self.loop_break_states); + self.visit_body(&node.body); + // Get the break states from the body of this loop, and restore the saved outer + // ones. + let break_states = + std::mem::replace(&mut self.loop_break_states, saved_break_states); + + // We may execute the `else` clause without ever executing the body, so merge in + // the pre-loop state before visiting `else`. + self.flow_merge(&pre_loop); + self.visit_body(&node.orelse); + + // Breaking out of a while loop bypasses the `else` clause, so merge in the break + // states after visiting `else`. + for break_state in break_states { + self.flow_merge(&break_state); + } + } + ast::Stmt::Break(_) => { + self.loop_break_states.push(self.flow_snapshot()); + } _ => { walk_stmt(self, stmt); } @@ -460,7 +495,7 @@ where if flags.contains(SymbolFlags::IS_USED) { let use_id = self.current_ast_ids().record_use(expr); - self.current_use_def_map().record_use(symbol, use_id); + self.current_use_def_map_mut().record_use(symbol, use_id); } walk_expr(self, expr); diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs index 79c7ad8a2a..f3e1afe982 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs @@ -194,6 +194,7 @@ pub(super) struct FlowSnapshot { definitions_by_symbol: IndexVec, } +#[derive(Debug)] pub(super) struct UseDefMapBuilder<'db> { /// Definition IDs array for `definitions_by_use` and `definitions_by_symbol` to slice into. all_definitions: Vec>, diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index ab3ebd106f..70071e1cd3 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1481,6 +1481,79 @@ mod tests { Ok(()) } + #[test] + fn while_loop() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + x = 1 + while flag: + x = 2 + ", + )?; + + // body of while loop may or may not run + assert_public_ty(&db, "/src/a.py", "x", "Literal[1, 2]"); + + Ok(()) + } + + #[test] + fn while_else_no_break() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + x = 1 + while flag: + x = 2 + else: + y = x + x = 3 + ", + )?; + + // body of the loop can't break, so we can get else, or body+else + // x must be 3, because else will always run + assert_public_ty(&db, "/src/a.py", "x", "Literal[3]"); + // y can be 1 or 2 because else always runs, and body may or may not run first + assert_public_ty(&db, "/src/a.py", "y", "Literal[1, 2]"); + + Ok(()) + } + + #[test] + fn while_else_may_break() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + x = 1 + y = 0 + while flag: + x = 2 + if flag2: + y = 4 + break + else: + y = x + x = 3 + ", + )?; + + // body may break: we can get just-body (only if we break), just-else, or body+else + assert_public_ty(&db, "/src/a.py", "x", "Literal[2, 3]"); + // if just-body were possible without the break, then 0 would be possible for y + // 1 and 2 both being possible for y shows that we can hit else with or without body + assert_public_ty(&db, "/src/a.py", "y", "Literal[1, 2, 4]"); + + Ok(()) + } + fn first_public_def<'db>(db: &'db TestDb, file: File, name: &str) -> Definition<'db> { let scope = global_scope(db, file); *use_def_map(db, scope)