fix double inference of standalone expressions (#14107)

This commit is contained in:
Micha Reiser 2024-11-05 15:50:31 +01:00 committed by GitHub
parent 05f97bae73
commit 05687285fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 67 additions and 23 deletions

View file

@ -125,6 +125,7 @@ impl<'db> SemanticIndex<'db> {
///
/// Use the Salsa cached [`symbol_table()`] query if you only need the
/// symbol table for a single scope.
#[track_caller]
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
self.symbol_tables[scope_id].clone()
}
@ -133,15 +134,18 @@ impl<'db> SemanticIndex<'db> {
///
/// Use the Salsa cached [`use_def_map()`] query if you only need the
/// use-def map for a single scope.
#[track_caller]
pub(super) fn use_def_map(&self, scope_id: FileScopeId) -> Arc<UseDefMap> {
self.use_def_maps[scope_id].clone()
}
#[track_caller]
pub(crate) fn ast_ids(&self, scope_id: FileScopeId) -> &AstIds {
&self.ast_ids[scope_id]
}
/// Returns the ID of the `expression`'s enclosing scope.
#[track_caller]
pub(crate) fn expression_scope_id(
&self,
expression: impl Into<ExpressionNodeKey>,
@ -151,11 +155,13 @@ impl<'db> SemanticIndex<'db> {
/// Returns the [`Scope`] of the `expression`'s enclosing scope.
#[allow(unused)]
#[track_caller]
pub(crate) fn expression_scope(&self, expression: impl Into<ExpressionNodeKey>) -> &Scope {
&self.scopes[self.expression_scope_id(expression)]
}
/// Returns the [`Scope`] with the given id.
#[track_caller]
pub(crate) fn scope(&self, id: FileScopeId) -> &Scope {
&self.scopes[id]
}
@ -172,6 +178,7 @@ impl<'db> SemanticIndex<'db> {
/// Returns the parent scope of `scope_id`.
#[allow(unused)]
#[track_caller]
pub(crate) fn parent_scope(&self, scope_id: FileScopeId) -> Option<&Scope> {
Some(&self.scopes[self.parent_scope_id(scope_id)?])
}
@ -195,6 +202,7 @@ impl<'db> SemanticIndex<'db> {
}
/// Returns the [`Definition`] salsa ingredient for `definition_key`.
#[track_caller]
pub(crate) fn definition(
&self,
definition_key: impl Into<DefinitionNodeKey>,
@ -206,6 +214,7 @@ impl<'db> SemanticIndex<'db> {
/// Panics if we have no expression ingredient for that node. We can only call this method for
/// standalone-inferable expressions, which we call `add_standalone_expression` for in
/// [`SemanticIndexBuilder`].
#[track_caller]
pub(crate) fn expression(
&self,
expression_key: impl Into<ExpressionNodeKey>,
@ -213,8 +222,18 @@ impl<'db> SemanticIndex<'db> {
self.expressions_by_node[&expression_key.into()]
}
pub(crate) fn try_expression(
&self,
expression_key: impl Into<ExpressionNodeKey>,
) -> Option<Expression<'db>> {
self.expressions_by_node
.get(&expression_key.into())
.copied()
}
/// Returns the id of the scope that `node` creates. This is different from [`Definition::scope`] which
/// returns the scope in which that definition is defined in.
#[track_caller]
pub(crate) fn node_scope(&self, node: NodeWithScopeRef) -> FileScopeId {
self.scopes_by_node[&node.node_key()]
}

View file

@ -1088,10 +1088,13 @@ where
// AST inspection, so we can't simplify here, need to record test expression for
// later checking)
self.visit_expr(test);
let constraint = self.record_expression_constraint(test);
let pre_if = self.flow_snapshot();
self.visit_expr(body);
let post_body = self.flow_snapshot();
self.flow_restore(pre_if);
self.record_negated_constraint(constraint);
self.visit_expr(orelse);
self.flow_merge(post_body);
}

View file

@ -603,7 +603,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
fn infer_region_expression(&mut self, expression: Expression<'db>) {
self.infer_expression(expression.node_ref(self.db));
self.infer_expression_impl(expression.node_ref(self.db));
}
/// Raise a diagnostic if the given type cannot be divided by zero.
@ -1018,7 +1018,7 @@ impl<'db> TypeInferenceBuilder<'db> {
elif_else_clauses,
} = if_statement;
self.infer_expression(test);
self.infer_standalone_expression(test);
self.infer_body(body);
for clause in elif_else_clauses {
@ -1028,7 +1028,9 @@ impl<'db> TypeInferenceBuilder<'db> {
body,
} = clause;
self.infer_optional_expression(test.as_ref());
if let Some(test) = &test {
self.infer_standalone_expression(test);
}
self.infer_body(body);
}
@ -1088,7 +1090,11 @@ impl<'db> TypeInferenceBuilder<'db> {
// Call into the context expression inference to validate that it evaluates
// to a valid context manager.
let context_expression_ty = self.infer_expression(&item.context_expr);
let context_expression_ty = if target.is_some() {
self.infer_standalone_expression(&item.context_expr)
} else {
self.infer_expression(&item.context_expr)
};
self.infer_context_expression(&item.context_expr, context_expression_ty, *is_async);
self.infer_optional_expression(target);
}
@ -1104,8 +1110,7 @@ impl<'db> TypeInferenceBuilder<'db> {
is_async: bool,
definition: Definition<'db>,
) {
let context_expr = self.index.expression(&with_item.context_expr);
self.extend(infer_expression_types(self.db, context_expr));
self.infer_standalone_expression(&with_item.context_expr);
let target_ty = self.infer_context_expression(
&with_item.context_expr,
@ -1305,9 +1310,7 @@ impl<'db> TypeInferenceBuilder<'db> {
cases,
} = match_statement;
let expression = self.index.expression(subject.as_ref());
let result = infer_expression_types(self.db, expression);
self.extend(result);
self.infer_standalone_expression(subject);
for case in cases {
let ast::MatchCase {
@ -1413,8 +1416,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
_ => {
// TODO: Remove this once we handle all possible assignment targets.
let expression = self.index.expression(value);
self.extend(infer_expression_types(self.db, expression));
self.infer_standalone_expression(value);
self.infer_expression(target);
}
}
@ -1427,9 +1429,7 @@ impl<'db> TypeInferenceBuilder<'db> {
name: &ast::ExprName,
definition: Definition<'db>,
) {
let expression = self.index.expression(value);
let result = infer_expression_types(self.db, expression);
self.extend(result);
self.infer_standalone_expression(value);
let value_ty = self.expression_ty(value);
let name_ast_id = name.scoped_ast_id(self.db, self.scope());
@ -1667,7 +1667,8 @@ impl<'db> TypeInferenceBuilder<'db> {
is_async: _,
} = for_statement;
self.infer_expression(iter);
self.infer_standalone_expression(iter);
// TODO more complex assignment targets
if let ast::Expr::Name(name) = &**target {
self.infer_definition(name);
@ -1685,10 +1686,7 @@ impl<'db> TypeInferenceBuilder<'db> {
is_async: bool,
definition: Definition<'db>,
) {
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db, expression);
self.extend(result);
let iterable_ty = self.expression_ty(iterable);
let iterable_ty = self.infer_standalone_expression(iterable);
let loop_var_value_ty = if is_async {
// TODO(Alex): async iterables/iterators!
@ -1968,7 +1966,25 @@ impl<'db> TypeInferenceBuilder<'db> {
expr.map(|expr| self.infer_annotation_expression(expr))
}
#[track_caller]
fn infer_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
debug_assert_eq!(
self.index.try_expression(expression),
None,
"Calling `self.infer_expression` on a standalone-expression is not allowed because it can lead to double-inference. Use `self.infer_standalone_expression` instead."
);
self.infer_expression_impl(expression)
}
fn infer_standalone_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
let standalone_expression = self.index.expression(expression);
let types = infer_expression_types(self.db, standalone_expression);
self.extend(types);
self.expression_ty(expression)
}
fn infer_expression_impl(&mut self, expression: &ast::Expr) -> Type<'db> {
let ty = match expression {
ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _ }) => Type::none(self.db),
ast::Expr::NumberLiteral(literal) => self.infer_number_literal_expression(literal),
@ -2171,7 +2187,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let Some(first_comprehension) = comprehensions_iter.next() else {
unreachable!("Comprehension must contain at least one generator");
};
self.infer_expression(&first_comprehension.iter);
self.infer_standalone_expression(&first_comprehension.iter);
}
fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> {
@ -2296,7 +2312,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} = comprehension;
if !is_first {
self.infer_expression(iter);
self.infer_standalone_expression(iter);
}
// TODO more complex assignment targets
if let ast::Expr::Name(name) = target {
@ -2387,7 +2403,7 @@ impl<'db> TypeInferenceBuilder<'db> {
orelse,
} = if_expression;
let test_ty = self.infer_expression(test);
let test_ty = self.infer_standalone_expression(test);
let body_ty = self.infer_expression(body);
let orelse_ty = self.infer_expression(orelse);
@ -3007,7 +3023,13 @@ impl<'db> TypeInferenceBuilder<'db> {
Self::infer_chained_boolean_types(
self.db,
*op,
values.iter().map(|value| self.infer_expression(value)),
values.iter().enumerate().map(|(index, value)| {
if index == values.len() - 1 {
self.infer_expression(value)
} else {
self.infer_standalone_expression(value)
}
}),
values.len(),
)
}