[ty] Support extending __all__ from an imported module even when the module is not an ExprName node (#17947)

This commit is contained in:
Alex Waygood 2025-05-08 23:54:19 +01:00 committed by GitHub
parent 9b694ada82
commit f51f1f7153
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 147 additions and 31 deletions

View file

@ -251,6 +251,53 @@ from ty_extensions import dunder_all_names
reveal_type(dunder_all_names(exporter)) reveal_type(dunder_all_names(exporter))
``` ```
### Augmenting list with a list or submodule `__all__` (2)
The same again, but the submodule is an attribute expression rather than a name expression:
`exporter/__init__.py`:
```py
```
`exporter/sub.py`:
```py
__all__ = ["foo"]
foo = 42
```
`exporter/sub2.py`:
```py
__all__ = ["bar"]
bar = 56
```
`module.py`:
```py
import exporter.sub
import exporter.sub2
__all__ = []
if True:
__all__.extend(exporter.sub.__all__)
__all__ += exporter.sub2.__all__
```
`main.py`:
```py
import module
from ty_extensions import dunder_all_names
reveal_type(dunder_all_names(module)) # revealed: tuple[Literal["bar"], Literal["foo"]]
```
### Extending with a list or submodule `__all__` ### Extending with a list or submodule `__all__`
`subexporter.py`: `subexporter.py`:

View file

@ -6,11 +6,10 @@ use ruff_python_ast::name::Name;
use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor}; use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor};
use ruff_python_ast::{self as ast}; use ruff_python_ast::{self as ast};
use crate::semantic_index::ast_ids::{HasScopedExpressionId, HasScopedUseId}; use crate::semantic_index::ast_ids::HasScopedExpressionId;
use crate::semantic_index::symbol::ScopeId; use crate::semantic_index::symbol::ScopeId;
use crate::semantic_index::{global_scope, semantic_index, SemanticIndex}; use crate::semantic_index::{global_scope, semantic_index, SemanticIndex};
use crate::symbol::{symbol_from_bindings, Boundness, Symbol}; use crate::types::{infer_expression_types, Truthiness, Type};
use crate::types::{infer_expression_types, Truthiness};
use crate::{resolve_module, Db, ModuleName}; use crate::{resolve_module, Db, ModuleName};
#[allow(clippy::ref_option)] #[allow(clippy::ref_option)]
@ -111,18 +110,8 @@ impl<'db> DunderAllNamesCollector<'db> {
if attr != "__all__" { if attr != "__all__" {
return false; return false;
} }
let Some(name_node) = value.as_name_expr() else { let Type::ModuleLiteral(module_literal) = self.standalone_expression_type(value)
return false; else {
};
let Symbol::Type(ty, Boundness::Bound) = symbol_from_bindings(
self.db,
self.index
.use_def_map(self.scope.file_scope_id(self.db))
.bindings_at_use(name_node.scoped_use_id(self.db, self.scope)),
) else {
return false;
};
let Some(module_literal) = ty.into_module_literal() else {
return false; return false;
}; };
let Some(module_dunder_all_names) = let Some(module_dunder_all_names) =
@ -198,14 +187,21 @@ impl<'db> DunderAllNamesCollector<'db> {
dunder_all_names(self.db, module.file()) dunder_all_names(self.db, module.file())
} }
/// Infer the type of a standalone expression.
///
/// # Panics
///
/// This function panics if `expr` was not marked as a standalone expression during semantic indexing.
fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> {
infer_expression_types(self.db, self.index.expression(expr))
.expression_type(expr.scoped_expression_id(self.db, self.scope))
}
/// Evaluate the given expression and return its truthiness. /// Evaluate the given expression and return its truthiness.
/// ///
/// Returns [`None`] if the expression type doesn't implement `__bool__` correctly. /// Returns [`None`] if the expression type doesn't implement `__bool__` correctly.
fn evaluate_test_expr(&self, expr: &ast::Expr) -> Option<Truthiness> { fn evaluate_test_expr(&self, expr: &ast::Expr) -> Option<Truthiness> {
infer_expression_types(self.db, self.index.expression(expr)) self.standalone_expression_type(expr).try_bool(self.db).ok()
.expression_type(expr.scoped_expression_id(self.db, self.scope))
.try_bool(self.db)
.ok()
} }
/// Add valid names to the set. /// Add valid names to the set.

View file

@ -387,6 +387,14 @@ impl<'db> SemanticIndex<'db> {
.copied() .copied()
} }
pub(crate) fn is_standalone_expression(
&self,
expression_key: impl Into<ExpressionNodeKey>,
) -> bool {
self.expressions_by_node
.contains_key(&expression_key.into())
}
/// Returns the id of the scope that `node` creates. /// Returns the id of the scope that `node` creates.
/// This is different from [`definition::Definition::scope`] which /// This is different from [`definition::Definition::scope`] which
/// returns the scope in which that definition is defined in. /// returns the scope in which that definition is defined in.

View file

@ -1479,23 +1479,37 @@ where
aug_assign @ ast::StmtAugAssign { aug_assign @ ast::StmtAugAssign {
range: _, range: _,
target, target,
op: _, op,
value, value,
}, },
) => { ) => {
debug_assert_eq!(&self.current_assignments, &[]); debug_assert_eq!(&self.current_assignments, &[]);
self.visit_expr(value); self.visit_expr(value);
// See https://docs.python.org/3/library/ast.html#ast.AugAssign match &**target {
if matches!( ast::Expr::Name(ast::ExprName { id, .. })
**target, if id == "__all__" && op.is_add() && self.in_module_scope() =>
ast::Expr::Attribute(_) | ast::Expr::Subscript(_) | ast::Expr::Name(_) {
) { if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) =
&**value
{
if attr == "__all__" {
self.add_standalone_expression(value);
}
}
self.push_assignment(aug_assign.into()); self.push_assignment(aug_assign.into());
self.visit_expr(target); self.visit_expr(target);
self.pop_assignment(); self.pop_assignment();
} else { }
ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => {
self.push_assignment(aug_assign.into());
self.visit_expr(target); self.visit_expr(target);
self.pop_assignment();
}
_ => {
self.visit_expr(target);
}
} }
} }
ast::Stmt::If(node) => { ast::Stmt::If(node) => {
@ -1934,6 +1948,12 @@ where
} }
walk_stmt(self, stmt); walk_stmt(self, stmt);
} }
ast::Stmt::Expr(ast::StmtExpr { value, range: _ }) if self.in_module_scope() => {
if let Some(expr) = dunder_all_extend_argument(value) {
self.add_standalone_expression(expr);
}
self.visit_expr(value);
}
_ => { _ => {
walk_stmt(self, stmt); walk_stmt(self, stmt);
} }
@ -2623,3 +2643,43 @@ impl<'a> Unpackable<'a> {
} }
} }
} }
/// Returns the single argument to `__all__.extend()`, if it is a call to `__all__.extend()`
/// where it looks like the argument might be a `submodule.__all__` expression.
/// Else, returns `None`.
fn dunder_all_extend_argument(value: &ast::Expr) -> Option<&ast::Expr> {
let ast::ExprCall {
func,
arguments:
ast::Arguments {
args,
keywords,
range: _,
},
..
} = value.as_call_expr()?;
let ast::ExprAttribute { value, attr, .. } = func.as_attribute_expr()?;
let ast::ExprName { id, .. } = value.as_name_expr()?;
if id != "__all__" {
return None;
}
if attr != "extend" {
return None;
}
if !keywords.is_empty() {
return None;
}
let [single_argument] = &**args else {
return None;
};
let ast::ExprAttribute { value, attr, .. } = single_argument.as_attribute_expr()?;
(attr == "__all__").then_some(value)
}

View file

@ -5505,7 +5505,12 @@ impl<'db> TypeInferenceBuilder<'db> {
ctx: _, ctx: _,
} = attribute; } = attribute;
let value_type = self.infer_expression(value); let value_type = if self.index.is_standalone_expression(&**value) {
self.infer_standalone_expression(value)
} else {
self.infer_expression(value)
};
let db = self.db(); let db = self.db();
value_type value_type