From 05687285fe0837b583bf6622cc5dcf8fe984700b Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 5 Nov 2024 15:50:31 +0100 Subject: [PATCH] fix double inference of standalone expressions (#14107) --- .../src/semantic_index.rs | 19 ++++++ .../src/semantic_index/builder.rs | 3 + .../src/types/infer.rs | 68 ++++++++++++------- 3 files changed, 67 insertions(+), 23 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 083779dc0c..1c57a2085d 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -125,6 +125,7 @@ impl<'db> SemanticIndex<'db> { /// /// Use the Salsa cached [`symbol_table()`] query if you only need the /// symbol table for a single scope. + #[track_caller] pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc { self.symbol_tables[scope_id].clone() } @@ -133,15 +134,18 @@ impl<'db> SemanticIndex<'db> { /// /// Use the Salsa cached [`use_def_map()`] query if you only need the /// use-def map for a single scope. + #[track_caller] pub(super) fn use_def_map(&self, scope_id: FileScopeId) -> Arc { self.use_def_maps[scope_id].clone() } + #[track_caller] pub(crate) fn ast_ids(&self, scope_id: FileScopeId) -> &AstIds { &self.ast_ids[scope_id] } /// Returns the ID of the `expression`'s enclosing scope. + #[track_caller] pub(crate) fn expression_scope_id( &self, expression: impl Into, @@ -151,11 +155,13 @@ impl<'db> SemanticIndex<'db> { /// Returns the [`Scope`] of the `expression`'s enclosing scope. #[allow(unused)] + #[track_caller] pub(crate) fn expression_scope(&self, expression: impl Into) -> &Scope { &self.scopes[self.expression_scope_id(expression)] } /// Returns the [`Scope`] with the given id. + #[track_caller] pub(crate) fn scope(&self, id: FileScopeId) -> &Scope { &self.scopes[id] } @@ -172,6 +178,7 @@ impl<'db> SemanticIndex<'db> { /// Returns the parent scope of `scope_id`. #[allow(unused)] + #[track_caller] pub(crate) fn parent_scope(&self, scope_id: FileScopeId) -> Option<&Scope> { Some(&self.scopes[self.parent_scope_id(scope_id)?]) } @@ -195,6 +202,7 @@ impl<'db> SemanticIndex<'db> { } /// Returns the [`Definition`] salsa ingredient for `definition_key`. + #[track_caller] pub(crate) fn definition( &self, definition_key: impl Into, @@ -206,6 +214,7 @@ impl<'db> SemanticIndex<'db> { /// Panics if we have no expression ingredient for that node. We can only call this method for /// standalone-inferable expressions, which we call `add_standalone_expression` for in /// [`SemanticIndexBuilder`]. + #[track_caller] pub(crate) fn expression( &self, expression_key: impl Into, @@ -213,8 +222,18 @@ impl<'db> SemanticIndex<'db> { self.expressions_by_node[&expression_key.into()] } + pub(crate) fn try_expression( + &self, + expression_key: impl Into, + ) -> Option> { + self.expressions_by_node + .get(&expression_key.into()) + .copied() + } + /// Returns the id of the scope that `node` creates. This is different from [`Definition::scope`] which /// returns the scope in which that definition is defined in. + #[track_caller] pub(crate) fn node_scope(&self, node: NodeWithScopeRef) -> FileScopeId { self.scopes_by_node[&node.node_key()] } diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index c4534f688c..811e3ecdb8 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -1088,10 +1088,13 @@ where // AST inspection, so we can't simplify here, need to record test expression for // later checking) self.visit_expr(test); + let constraint = self.record_expression_constraint(test); let pre_if = self.flow_snapshot(); self.visit_expr(body); let post_body = self.flow_snapshot(); self.flow_restore(pre_if); + + self.record_negated_constraint(constraint); self.visit_expr(orelse); self.flow_merge(post_body); } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 1e18e8c51d..fd4b21758b 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -603,7 +603,7 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_region_expression(&mut self, expression: Expression<'db>) { - self.infer_expression(expression.node_ref(self.db)); + self.infer_expression_impl(expression.node_ref(self.db)); } /// Raise a diagnostic if the given type cannot be divided by zero. @@ -1018,7 +1018,7 @@ impl<'db> TypeInferenceBuilder<'db> { elif_else_clauses, } = if_statement; - self.infer_expression(test); + self.infer_standalone_expression(test); self.infer_body(body); for clause in elif_else_clauses { @@ -1028,7 +1028,9 @@ impl<'db> TypeInferenceBuilder<'db> { body, } = clause; - self.infer_optional_expression(test.as_ref()); + if let Some(test) = &test { + self.infer_standalone_expression(test); + } self.infer_body(body); } @@ -1088,7 +1090,11 @@ impl<'db> TypeInferenceBuilder<'db> { // Call into the context expression inference to validate that it evaluates // to a valid context manager. - let context_expression_ty = self.infer_expression(&item.context_expr); + let context_expression_ty = if target.is_some() { + self.infer_standalone_expression(&item.context_expr) + } else { + self.infer_expression(&item.context_expr) + }; self.infer_context_expression(&item.context_expr, context_expression_ty, *is_async); self.infer_optional_expression(target); } @@ -1104,8 +1110,7 @@ impl<'db> TypeInferenceBuilder<'db> { is_async: bool, definition: Definition<'db>, ) { - let context_expr = self.index.expression(&with_item.context_expr); - self.extend(infer_expression_types(self.db, context_expr)); + self.infer_standalone_expression(&with_item.context_expr); let target_ty = self.infer_context_expression( &with_item.context_expr, @@ -1305,9 +1310,7 @@ impl<'db> TypeInferenceBuilder<'db> { cases, } = match_statement; - let expression = self.index.expression(subject.as_ref()); - let result = infer_expression_types(self.db, expression); - self.extend(result); + self.infer_standalone_expression(subject); for case in cases { let ast::MatchCase { @@ -1413,8 +1416,7 @@ impl<'db> TypeInferenceBuilder<'db> { } _ => { // TODO: Remove this once we handle all possible assignment targets. - let expression = self.index.expression(value); - self.extend(infer_expression_types(self.db, expression)); + self.infer_standalone_expression(value); self.infer_expression(target); } } @@ -1427,9 +1429,7 @@ impl<'db> TypeInferenceBuilder<'db> { name: &ast::ExprName, definition: Definition<'db>, ) { - let expression = self.index.expression(value); - let result = infer_expression_types(self.db, expression); - self.extend(result); + self.infer_standalone_expression(value); let value_ty = self.expression_ty(value); let name_ast_id = name.scoped_ast_id(self.db, self.scope()); @@ -1667,7 +1667,8 @@ impl<'db> TypeInferenceBuilder<'db> { is_async: _, } = for_statement; - self.infer_expression(iter); + self.infer_standalone_expression(iter); + // TODO more complex assignment targets if let ast::Expr::Name(name) = &**target { self.infer_definition(name); @@ -1685,10 +1686,7 @@ impl<'db> TypeInferenceBuilder<'db> { is_async: bool, definition: Definition<'db>, ) { - let expression = self.index.expression(iterable); - let result = infer_expression_types(self.db, expression); - self.extend(result); - let iterable_ty = self.expression_ty(iterable); + let iterable_ty = self.infer_standalone_expression(iterable); let loop_var_value_ty = if is_async { // TODO(Alex): async iterables/iterators! @@ -1968,7 +1966,25 @@ impl<'db> TypeInferenceBuilder<'db> { expr.map(|expr| self.infer_annotation_expression(expr)) } + #[track_caller] fn infer_expression(&mut self, expression: &ast::Expr) -> Type<'db> { + debug_assert_eq!( + self.index.try_expression(expression), + None, + "Calling `self.infer_expression` on a standalone-expression is not allowed because it can lead to double-inference. Use `self.infer_standalone_expression` instead." + ); + + self.infer_expression_impl(expression) + } + + fn infer_standalone_expression(&mut self, expression: &ast::Expr) -> Type<'db> { + let standalone_expression = self.index.expression(expression); + let types = infer_expression_types(self.db, standalone_expression); + self.extend(types); + self.expression_ty(expression) + } + + fn infer_expression_impl(&mut self, expression: &ast::Expr) -> Type<'db> { let ty = match expression { ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _ }) => Type::none(self.db), ast::Expr::NumberLiteral(literal) => self.infer_number_literal_expression(literal), @@ -2171,7 +2187,7 @@ impl<'db> TypeInferenceBuilder<'db> { let Some(first_comprehension) = comprehensions_iter.next() else { unreachable!("Comprehension must contain at least one generator"); }; - self.infer_expression(&first_comprehension.iter); + self.infer_standalone_expression(&first_comprehension.iter); } fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> { @@ -2296,7 +2312,7 @@ impl<'db> TypeInferenceBuilder<'db> { } = comprehension; if !is_first { - self.infer_expression(iter); + self.infer_standalone_expression(iter); } // TODO more complex assignment targets if let ast::Expr::Name(name) = target { @@ -2387,7 +2403,7 @@ impl<'db> TypeInferenceBuilder<'db> { orelse, } = if_expression; - let test_ty = self.infer_expression(test); + let test_ty = self.infer_standalone_expression(test); let body_ty = self.infer_expression(body); let orelse_ty = self.infer_expression(orelse); @@ -3007,7 +3023,13 @@ impl<'db> TypeInferenceBuilder<'db> { Self::infer_chained_boolean_types( self.db, *op, - values.iter().map(|value| self.infer_expression(value)), + values.iter().enumerate().map(|(index, value)| { + if index == values.len() - 1 { + self.infer_expression(value) + } else { + self.infer_standalone_expression(value) + } + }), values.len(), ) }