[ty] Call into type inference more when resolving __all__

This commit is contained in:
Alex Waygood 2025-05-08 12:33:05 +01:00
parent 4f890b2867
commit a9d3e2e253
2 changed files with 47 additions and 37 deletions

View file

@ -6,11 +6,10 @@ use ruff_python_ast::name::Name;
use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor};
use ruff_python_ast::{self as ast};
use crate::semantic_index::ast_ids::{HasScopedExpressionId, HasScopedUseId};
use crate::semantic_index::ast_ids::HasScopedExpressionId;
use crate::semantic_index::symbol::ScopeId;
use crate::semantic_index::{global_scope, semantic_index, SemanticIndex};
use crate::symbol::{symbol_from_bindings, Boundness, Symbol};
use crate::types::{infer_expression_types, Truthiness};
use crate::types::{infer_expression_types, infer_scope_types, Truthiness, Type};
use crate::{resolve_module, Db, ModuleName};
#[allow(clippy::ref_option)]
@ -93,11 +92,33 @@ impl<'db> DunderAllNamesCollector<'db> {
self.origin = Some(origin);
}
/// Extends the current set of names with the names from the given expression which can be
/// either a list of names or a module's `__all__` variable.
fn infer_expression_type(&self, expr: &ast::Expr) -> Type<'db> {
let expression_id = expr.scoped_expression_id(self.db, self.scope);
infer_scope_types(self.db, self.scope).expression_type(expression_id)
}
fn extend_from_tuple(&mut self, expr: &ast::Expr) -> bool {
let Type::Tuple(tuple) = self.infer_expression_type(expr) else {
return false;
};
for element in tuple.elements(self.db) {
let Type::StringLiteral(literal) = element else {
return false;
};
self.names.insert(Name::new(literal.value(self.db)));
}
true
}
/// Extends the current set of names with the names from the given expression.
/// The given expression can be one of the following:
/// - A literal list that consists only of string literals.
/// - An attribute expression where the value is inferred as a module literal
/// and the attribute is `"__all__"`.
/// - Any other expression that is inferred as being a heterogeneous tuple of string literals.
///
/// Returns `true` if the expression is a valid list or module `__all__`, `false` otherwise.
fn extend_from_list_or_module(&mut self, expr: &ast::Expr) -> bool {
fn extend(&mut self, expr: &ast::Expr) -> bool {
match expr {
// `__all__ += [...]`
// `__all__.extend([...])`
@ -107,21 +128,10 @@ impl<'db> DunderAllNamesCollector<'db> {
// `__all__.extend(module.__all__)`
ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => {
if attr != "__all__" {
return false;
return self.extend_from_tuple(expr);
}
let Some(name_node) = value.as_name_expr() else {
return false;
};
let Symbol::Type(ty, Boundness::Bound) = symbol_from_bindings(
self.db,
self.index
.use_def_map(self.scope.file_scope_id(self.db))
.bindings_at_use(name_node.scoped_use_id(self.db, self.scope)),
) else {
return false;
};
let Some(module_literal) = ty.into_module_literal() else {
return false;
let Type::ModuleLiteral(module_literal) = self.infer_expression_type(value) else {
return self.extend_from_tuple(expr);
};
let Some(module_dunder_all_names) =
dunder_all_names(self.db, module_literal.module(self.db).file())
@ -133,7 +143,7 @@ impl<'db> DunderAllNamesCollector<'db> {
true
}
_ => false,
_ => self.extend_from_tuple(expr),
}
}
@ -155,14 +165,14 @@ impl<'db> DunderAllNamesCollector<'db> {
// `__all__.extend([...])`
// `__all__.extend(module.__all__)`
"extend" => {
if !self.extend_from_list_or_module(argument) {
if !self.extend(argument) {
return false;
}
}
// `__all__.append(...)`
"append" => {
let Some(name) = create_name(argument) else {
let Some(name) = self.create_name(argument) else {
return false;
};
self.names.insert(name);
@ -170,7 +180,7 @@ impl<'db> DunderAllNamesCollector<'db> {
// `__all__.remove(...)`
"remove" => {
let Some(name) = create_name(argument) else {
let Some(name) = self.create_name(argument) else {
return false;
};
self.names.remove(&name);
@ -211,7 +221,7 @@ impl<'db> DunderAllNamesCollector<'db> {
/// Returns `false` if any of the names are invalid.
fn add_names(&mut self, exprs: &[ast::Expr]) -> bool {
for expr in exprs {
let Some(name) = create_name(expr) else {
let Some(name) = self.create_name(expr) else {
return false;
};
self.names.insert(name);
@ -233,6 +243,14 @@ impl<'db> DunderAllNamesCollector<'db> {
Some(self.names)
}
}
/// Create and return a [`Name`] from the given expression, [`None`] if it is an invalid expression
/// for a `__all__` element.
fn create_name(&self, expr: &ast::Expr) -> Option<Name> {
self.infer_expression_type(expr)
.into_string_literal()
.map(|literal| Name::new(literal.value(self.db)))
}
}
impl<'db> StatementVisitor<'db> for DunderAllNamesCollector<'db> {
@ -264,9 +282,7 @@ impl<'db> StatementVisitor<'db> for DunderAllNamesCollector<'db> {
} else {
// `from module import __all__`
// `from module import __all__ as __all__`
if name != "__all__"
|| asname.as_ref().is_some_and(|asname| asname != "__all__")
{
if asname.as_ref().unwrap_or(name) != "__all__" {
continue;
}
@ -330,7 +346,7 @@ impl<'db> StatementVisitor<'db> for DunderAllNamesCollector<'db> {
if !is_dunder_all(target) {
return;
}
if !self.extend_from_list_or_module(value) {
if !self.extend(value) {
self.invalid = true;
}
}
@ -462,9 +478,3 @@ enum DunderAllOrigin {
fn is_dunder_all(expr: &ast::Expr) -> bool {
matches!(expr, ast::Expr::Name(ast::ExprName { id, .. }) if id == "__all__")
}
/// Create and return a [`Name`] from the given expression, [`None`] if it is an invalid expression
/// for a `__all__` element.
fn create_name(expr: &ast::Expr) -> Option<Name> {
Some(Name::new(expr.as_string_literal_expr()?.value.to_str()))
}

View file

@ -7908,7 +7908,7 @@ impl<'db> IntersectionType<'db> {
#[salsa::interned(debug)]
pub struct StringLiteralType<'db> {
#[return_ref]
value: Box<str>,
pub(crate) value: Box<str>,
}
impl<'db> StringLiteralType<'db> {
@ -7953,7 +7953,7 @@ impl SliceLiteralType<'_> {
#[salsa::interned(debug)]
pub struct TupleType<'db> {
#[return_ref]
elements: Box<[Type<'db>]>,
pub(crate) elements: Box<[Type<'db>]>,
}
impl<'db> TupleType<'db> {