[ty] basic narrowing on attribute and subscript expressions (#17643)

## Summary

This PR closes astral-sh/ty#164.

This PR introduces a basic type narrowing mechanism for
attribute/subscript expressions.
Member accesses, int literal subscripts, string literal subscripts are
supported (same as mypy and pyright).

## Test Plan

New test cases are added to `mdtest/narrow/complex_target.md`.

---------

Co-authored-by: David Peter <mail@david-peter.de>
This commit is contained in:
Shunsuke Shibayama 2025-06-17 18:07:46 +09:00 committed by GitHub
parent 390918e790
commit 342b2665db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 739 additions and 327 deletions

View file

@ -751,7 +751,8 @@ reveal_type(C.pure_class_variable) # revealed: Unknown
# and the assignment is properly attributed to the class method.
# error: [invalid-attribute-access] "Cannot assign to instance attribute `pure_class_variable` from the class object `<class 'C'>`"
C.pure_class_variable = "overwritten on class"
# TODO: should be no error
# error: [unresolved-attribute] "Attribute `pure_class_variable` can only be accessed on instances, not on the class object `<class 'C'>` itself."
reveal_type(C.pure_class_variable) # revealed: Literal["overwritten on class"]
c_instance = C()

View file

@ -87,6 +87,12 @@ class _:
reveal_type(a.y) # revealed: Unknown | None
reveal_type(a.z) # revealed: Unknown | None
a = A()
# error: [unresolved-attribute]
a.dynamically_added = 0
# error: [unresolved-attribute]
reveal_type(a.dynamically_added) # revealed: Literal[0]
# error: [unresolved-reference]
does.nt.exist = 0
# error: [unresolved-reference]

View file

@ -0,0 +1,224 @@
# Narrowing for complex targets (attribute expressions, subscripts)
We support type narrowing for attributes and subscripts.
## Attribute narrowing
### Basic
```py
from ty_extensions import Unknown
class C:
x: int | None = None
c = C()
reveal_type(c.x) # revealed: int | None
if c.x is not None:
reveal_type(c.x) # revealed: int
else:
reveal_type(c.x) # revealed: None
if c.x is not None:
c.x = None
reveal_type(c.x) # revealed: None
c = C()
if c.x is None:
c.x = 1
reveal_type(c.x) # revealed: int
class _:
reveal_type(c.x) # revealed: int
c = C()
class _:
if c.x is None:
c.x = 1
reveal_type(c.x) # revealed: int
# TODO: should be `int`
reveal_type(c.x) # revealed: int | None
class D:
x = None
def unknown() -> Unknown:
return 1
d = D()
reveal_type(d.x) # revealed: Unknown | None
d.x = 1
reveal_type(d.x) # revealed: Literal[1]
d.x = unknown()
reveal_type(d.x) # revealed: Unknown
```
Narrowing can be "reset" by assigning to the attribute:
```py
c = C()
if c.x is None:
reveal_type(c.x) # revealed: None
c.x = 1
reveal_type(c.x) # revealed: Literal[1]
c.x = None
reveal_type(c.x) # revealed: None
reveal_type(c.x) # revealed: int | None
```
Narrowing can also be "reset" by assigning to the object:
```py
c = C()
if c.x is None:
reveal_type(c.x) # revealed: None
c = C()
reveal_type(c.x) # revealed: int | None
reveal_type(c.x) # revealed: int | None
```
### Multiple predicates
```py
class C:
value: str | None
def foo(c: C):
if c.value and len(c.value):
reveal_type(c.value) # revealed: str & ~AlwaysFalsy
# error: [invalid-argument-type] "Argument to function `len` is incorrect: Expected `Sized`, found `str | None`"
if len(c.value) and c.value:
reveal_type(c.value) # revealed: str & ~AlwaysFalsy
if c.value is None or not len(c.value):
reveal_type(c.value) # revealed: str | None
else: # c.value is not None and len(c.value)
# TODO: should be # `str & ~AlwaysFalsy`
reveal_type(c.value) # revealed: str
```
### Generic class
```toml
[environment]
python-version = "3.12"
```
```py
class C[T]:
x: T
y: T
def __init__(self, x: T):
self.x = x
self.y = x
def f(a: int | None):
c = C(a)
reveal_type(c.x) # revealed: int | None
reveal_type(c.y) # revealed: int | None
if c.x is not None:
reveal_type(c.x) # revealed: int
# In this case, it may seem like we can narrow it down to `int`,
# but different values may be reassigned to `x` and `y` in another place.
reveal_type(c.y) # revealed: int | None
def g[T](c: C[T]):
reveal_type(c.x) # revealed: T
reveal_type(c.y) # revealed: T
reveal_type(c) # revealed: C[T]
if isinstance(c.x, int):
reveal_type(c.x) # revealed: T & int
reveal_type(c.y) # revealed: T
reveal_type(c) # revealed: C[T]
if isinstance(c.x, int) and isinstance(c.y, int):
reveal_type(c.x) # revealed: T & int
reveal_type(c.y) # revealed: T & int
# TODO: Probably better if inferred as `C[T & int]` (mypy and pyright don't support this)
reveal_type(c) # revealed: C[T]
```
### With intermediate scopes
```py
class C:
def __init__(self):
self.x: int | None = None
self.y: int | None = None
c = C()
reveal_type(c.x) # revealed: int | None
if c.x is not None:
reveal_type(c.x) # revealed: int
reveal_type(c.y) # revealed: int | None
if c.x is not None:
def _():
reveal_type(c.x) # revealed: Unknown | int | None
def _():
if c.x is not None:
reveal_type(c.x) # revealed: (Unknown & ~None) | int
```
## Subscript narrowing
### Number subscript
```py
def _(t1: tuple[int | None, int | None], t2: tuple[int, int] | tuple[None, None]):
if t1[0] is not None:
reveal_type(t1[0]) # revealed: int
reveal_type(t1[1]) # revealed: int | None
n = 0
if t1[n] is not None:
# Non-literal subscript narrowing are currently not supported, as well as mypy, pyright
reveal_type(t1[0]) # revealed: int | None
reveal_type(t1[n]) # revealed: int | None
reveal_type(t1[1]) # revealed: int | None
if t2[0] is not None:
# TODO: should be int
reveal_type(t2[0]) # revealed: Unknown & ~None
# TODO: should be int
reveal_type(t2[1]) # revealed: Unknown
```
### String subscript
```py
def _(d: dict[str, str | None]):
if d["a"] is not None:
reveal_type(d["a"]) # revealed: str
reveal_type(d["b"]) # revealed: str | None
```
## Combined attribute and subscript narrowing
```py
class C:
def __init__(self):
self.x: tuple[int | None, int | None] = (None, None)
class D:
def __init__(self):
self.c: tuple[C] | None = None
d = D()
if d.c is not None and d.c[0].x[0] is not None:
reveal_type(d.c[0].x[0]) # revealed: int
```

View file

@ -135,7 +135,7 @@ class _:
class _3:
reveal_type(a) # revealed: A
# TODO: should be `D | None`
reveal_type(a.b.c1.d) # revealed: D
reveal_type(a.b.c1.d) # revealed: Unknown
a.b.c1 = C()
a.b.c1.d = D()
@ -173,12 +173,10 @@ def f(x: str | None):
reveal_type(g) # revealed: str
if a.x is not None:
# TODO(#17643): should be `Unknown | str`
reveal_type(a.x) # revealed: Unknown | str | None
reveal_type(a.x) # revealed: (Unknown & ~None) | str
if l[0] is not None:
# TODO(#17643): should be `str`
reveal_type(l[0]) # revealed: str | None
reveal_type(l[0]) # revealed: str
class C:
if x is not None:
@ -191,12 +189,10 @@ def f(x: str | None):
reveal_type(g) # revealed: str
if a.x is not None:
# TODO(#17643): should be `Unknown | str`
reveal_type(a.x) # revealed: Unknown | str | None
reveal_type(a.x) # revealed: (Unknown & ~None) | str
if l[0] is not None:
# TODO(#17643): should be `str`
reveal_type(l[0]) # revealed: str | None
reveal_type(l[0]) # revealed: str
# TODO: should be str
# This could be fixed if we supported narrowing with if clauses in comprehensions.
@ -241,22 +237,18 @@ def f(x: str | None):
reveal_type(a.x) # revealed: Unknown | str | None
class D:
# TODO(#17643): should be `Unknown | str`
reveal_type(a.x) # revealed: Unknown | str | None
reveal_type(a.x) # revealed: (Unknown & ~None) | str
# TODO(#17643): should be `Unknown | str`
[reveal_type(a.x) for _ in range(1)] # revealed: Unknown | str | None
[reveal_type(a.x) for _ in range(1)] # revealed: (Unknown & ~None) | str
if l[0] is not None:
def _():
reveal_type(l[0]) # revealed: str | None
class D:
# TODO(#17643): should be `str`
reveal_type(l[0]) # revealed: str | None
reveal_type(l[0]) # revealed: str
# TODO(#17643): should be `str`
[reveal_type(l[0]) for _ in range(1)] # revealed: str | None
[reveal_type(l[0]) for _ in range(1)] # revealed: str
```
### Narrowing constraints introduced in multiple scopes
@ -299,24 +291,20 @@ def f(x: str | Literal[1] | None):
if a.x is not None:
def _():
if a.x != 1:
# TODO(#17643): should be `Unknown | str | None`
reveal_type(a.x) # revealed: Unknown | str | Literal[1] | None
reveal_type(a.x) # revealed: (Unknown & ~Literal[1]) | str | None
class D:
if a.x != 1:
# TODO(#17643): should be `Unknown | str`
reveal_type(a.x) # revealed: Unknown | str | Literal[1] | None
reveal_type(a.x) # revealed: (Unknown & ~Literal[1] & ~None) | str
if l[0] is not None:
def _():
if l[0] != 1:
# TODO(#17643): should be `str | None`
reveal_type(l[0]) # revealed: str | Literal[1] | None
reveal_type(l[0]) # revealed: str | None
class D:
if l[0] != 1:
# TODO(#17643): should be `str`
reveal_type(l[0]) # revealed: str | Literal[1] | None
reveal_type(l[0]) # revealed: str
```
### Narrowing constraints with bindings in class scope, and nested scopes

View file

@ -220,8 +220,7 @@ def _(a: tuple[str, int] | tuple[int, str], c: C[Any]):
if reveal_type(is_int(a[0])): # revealed: TypeIs[int @ a[0]]
# TODO: Should be `tuple[int, str]`
reveal_type(a) # revealed: tuple[str, int] | tuple[int, str]
# TODO: Should be `int`
reveal_type(a[0]) # revealed: Unknown
reveal_type(a[0]) # revealed: Unknown & int
# TODO: Should be `TypeGuard[str @ c.v]`
if reveal_type(guard_str(c.v)): # revealed: @Todo(`TypeGuard[]` special form)
@ -231,8 +230,7 @@ def _(a: tuple[str, int] | tuple[int, str], c: C[Any]):
if reveal_type(is_int(c.v)): # revealed: TypeIs[int @ c.v]
reveal_type(c) # revealed: C[Any]
# TODO: Should be `int`
reveal_type(c.v) # revealed: Any
reveal_type(c.v) # revealed: Any & int
```
Indirect usage is supported within the same scope:

View file

@ -17,4 +17,5 @@ setuptools # vendors packaging, see above
spack # slow, success, but mypy-primer hangs processing the output
spark # too many iterations
steam.py # hangs (single threaded)
tornado # bad use-def map (https://github.com/astral-sh/ty/issues/365)
xarray # too many iterations

View file

@ -110,7 +110,6 @@ stone
strawberry
streamlit
sympy
tornado
trio
twine
typeshed-stats

View file

@ -661,6 +661,7 @@ fn place_by_id<'db>(
// See mdtest/known_constants.md#user-defined-type_checking for details.
let is_considered_non_modifiable = place_table(db, scope)
.place_expr(place_id)
.expr
.is_name_and(|name| matches!(name, "__slots__" | "TYPE_CHECKING"));
if scope.file(db).is_stub(db.upcast()) {
@ -1124,8 +1125,8 @@ mod implicit_globals {
module_type_symbol_table
.places()
.filter(|symbol| symbol.is_declared() && symbol.is_name())
.map(semantic_index::place::PlaceExpr::expect_name)
.filter(|place| place.is_declared() && place.is_name())
.map(semantic_index::place::PlaceExprWithFlags::expect_name)
.filter(|symbol_name| {
!matches!(&***symbol_name, "__dict__" | "__getattr__" | "__init__")
})

View file

@ -37,8 +37,8 @@ mod reachability_constraints;
mod use_def;
pub(crate) use self::use_def::{
BindingWithConstraints, BindingWithConstraintsIterator, DeclarationWithConstraint,
DeclarationsIterator,
ApplicableConstraints, BindingWithConstraints, BindingWithConstraintsIterator,
DeclarationWithConstraint, DeclarationsIterator,
};
type PlaceSet = hashbrown::HashMap<ScopedPlaceId, (), FxBuildHasher>;

View file

@ -33,7 +33,7 @@ use crate::semantic_index::definition::{
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::place::{
FileScopeId, NodeWithScopeKey, NodeWithScopeKind, NodeWithScopeRef, PlaceExpr,
PlaceTableBuilder, Scope, ScopeId, ScopeKind, ScopedPlaceId,
PlaceExprWithFlags, PlaceTableBuilder, Scope, ScopeId, ScopeKind, ScopedPlaceId,
};
use crate::semantic_index::predicate::{
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, ScopedPredicateId,
@ -295,6 +295,15 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
// If the scope that we just popped off is an eager scope, we need to "lock" our view of
// which bindings reach each of the uses in the scope. Loop through each enclosing scope,
// looking for any that bind each place.
// TODO: Bindings in eager nested scopes also need to be recorded. For example:
// ```python
// class C:
// x: int | None = None
// c = C()
// class _:
// c.x = 1
// reveal_type(c.x) # revealed: Literal[1]
// ```
for enclosing_scope_info in self.scope_stack.iter().rev() {
let enclosing_scope_id = enclosing_scope_info.file_scope_id;
let enclosing_scope_kind = self.scopes[enclosing_scope_id].kind();
@ -306,7 +315,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
// it may refer to the enclosing scope bindings
// so we also need to snapshot the bindings of the enclosing scope.
let Some(enclosing_place_id) = enclosing_place_table.place_id_by_expr(nested_place)
let Some(enclosing_place_id) =
enclosing_place_table.place_id_by_expr(&nested_place.expr)
else {
continue;
};
@ -388,7 +398,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
/// Add a place to the place table and the use-def map.
/// Return the [`ScopedPlaceId`] that uniquely identifies the place in both.
fn add_place(&mut self, place_expr: PlaceExpr) -> ScopedPlaceId {
fn add_place(&mut self, place_expr: PlaceExprWithFlags) -> ScopedPlaceId {
let (place_id, added) = self.current_place_table().add_place(place_expr);
if added {
self.current_use_def_map_mut().add_place(place_id);
@ -1863,7 +1873,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
walk_stmt(self, stmt);
for target in targets {
if let Ok(target) = PlaceExpr::try_from(target) {
let place_id = self.add_place(target);
let place_id = self.add_place(PlaceExprWithFlags::new(target));
self.current_place_table().mark_place_used(place_id);
self.delete_binding(place_id);
}
@ -1898,7 +1908,8 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
ast::Expr::Name(ast::ExprName { ctx, .. })
| ast::Expr::Attribute(ast::ExprAttribute { ctx, .. })
| ast::Expr::Subscript(ast::ExprSubscript { ctx, .. }) => {
if let Ok(mut place_expr) = PlaceExpr::try_from(expr) {
if let Ok(place_expr) = PlaceExpr::try_from(expr) {
let mut place_expr = PlaceExprWithFlags::new(place_expr);
if self.is_method_of_class().is_some()
&& place_expr.is_instance_attribute_candidate()
{
@ -1906,7 +1917,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
// i.e. typically `self` or `cls`.
let accessed_object_refers_to_first_parameter = self
.current_first_parameter_name
.is_some_and(|fst| place_expr.root_name() == fst);
.is_some_and(|fst| place_expr.expr.root_name() == fst);
if accessed_object_refers_to_first_parameter && place_expr.is_member() {
place_expr.mark_instance_attribute();

View file

@ -30,6 +30,7 @@
use crate::list::{List, ListBuilder, ListSetReverseIterator, ListStorage};
use crate::semantic_index::ast_ids::ScopedUseId;
use crate::semantic_index::place::FileScopeId;
use crate::semantic_index::predicate::ScopedPredicateId;
/// A narrowing constraint associated with a live binding.
@ -42,6 +43,7 @@ pub(crate) type ScopedNarrowingConstraint = List<ScopedNarrowingConstraintPredic
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ConstraintKey {
NarrowingConstraint(ScopedNarrowingConstraint),
EagerNestedScope(FileScopeId),
UseId(ScopedUseId),
}

View file

@ -18,7 +18,7 @@ use crate::node_key::NodeKey;
use crate::semantic_index::reachability_constraints::ScopedReachabilityConstraintId;
use crate::semantic_index::{PlaceSet, SemanticIndex, semantic_index};
#[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum PlaceExprSubSegment {
/// A member access, e.g. `.y` in `x.y`
Member(ast::name::Name),
@ -38,13 +38,10 @@ impl PlaceExprSubSegment {
}
/// An expression that can be the target of a `Definition`.
/// If you want to perform a comparison based on the equality of segments (without including
/// flags), use [`PlaceSegments`].
#[derive(Eq, PartialEq, Debug)]
pub struct PlaceExpr {
root_name: Name,
sub_segments: SmallVec<[PlaceExprSubSegment; 1]>,
flags: PlaceFlags,
}
impl std::fmt::Display for PlaceExpr {
@ -151,23 +148,27 @@ impl TryFrom<&ast::Expr> for PlaceExpr {
}
}
impl TryFrom<ast::ExprRef<'_>> for PlaceExpr {
type Error = ();
fn try_from(expr: ast::ExprRef) -> Result<Self, ()> {
match expr {
ast::ExprRef::Name(name) => Ok(PlaceExpr::name(name.id.clone())),
ast::ExprRef::Attribute(attr) => PlaceExpr::try_from(attr),
ast::ExprRef::Subscript(subscript) => PlaceExpr::try_from(subscript),
_ => Err(()),
}
}
}
impl PlaceExpr {
pub(super) fn name(name: Name) -> Self {
pub(crate) fn name(name: Name) -> Self {
Self {
root_name: name,
sub_segments: smallvec![],
flags: PlaceFlags::empty(),
}
}
fn insert_flags(&mut self, flags: PlaceFlags) {
self.flags.insert(flags);
}
pub(super) fn mark_instance_attribute(&mut self) {
self.flags.insert(PlaceFlags::IS_INSTANCE_ATTRIBUTE);
}
pub(crate) fn root_name(&self) -> &Name {
&self.root_name
}
@ -191,6 +192,66 @@ impl PlaceExpr {
&self.root_name
}
/// Is the place just a name?
pub fn is_name(&self) -> bool {
self.sub_segments.is_empty()
}
pub fn is_name_and(&self, f: impl FnOnce(&str) -> bool) -> bool {
self.is_name() && f(&self.root_name)
}
/// Does the place expression have the form `<object>.member`?
pub fn is_member(&self) -> bool {
self.sub_segments
.last()
.is_some_and(|last| last.as_member().is_some())
}
fn root_exprs(&self) -> RootExprs<'_> {
RootExprs {
expr_ref: self.into(),
len: self.sub_segments.len(),
}
}
}
/// A [`PlaceExpr`] with flags, e.g. whether it is used, bound, an instance attribute, etc.
#[derive(Eq, PartialEq, Debug)]
pub struct PlaceExprWithFlags {
pub(crate) expr: PlaceExpr,
flags: PlaceFlags,
}
impl std::fmt::Display for PlaceExprWithFlags {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.expr.fmt(f)
}
}
impl PlaceExprWithFlags {
pub(crate) fn new(expr: PlaceExpr) -> Self {
PlaceExprWithFlags {
expr,
flags: PlaceFlags::empty(),
}
}
fn name(name: Name) -> Self {
PlaceExprWithFlags {
expr: PlaceExpr::name(name),
flags: PlaceFlags::empty(),
}
}
fn insert_flags(&mut self, flags: PlaceFlags) {
self.flags.insert(flags);
}
pub(super) fn mark_instance_attribute(&mut self) {
self.flags.insert(PlaceFlags::IS_INSTANCE_ATTRIBUTE);
}
/// If the place expression has the form `<NAME>.<MEMBER>`
/// (meaning it *may* be an instance attribute),
/// return `Some(<MEMBER>)`. Else, return `None`.
@ -202,8 +263,8 @@ impl PlaceExpr {
/// parameter of the method (i.e. `self`). To answer those questions,
/// use [`Self::as_instance_attribute`].
pub(super) fn as_instance_attribute_candidate(&self) -> Option<&Name> {
if self.sub_segments.len() == 1 {
self.sub_segments[0].as_member()
if self.expr.sub_segments.len() == 1 {
self.expr.sub_segments[0].as_member()
} else {
None
}
@ -227,6 +288,16 @@ impl PlaceExpr {
self.as_instance_attribute().map(Name::as_str) == Some(name)
}
/// Return `Some(<ATTRIBUTE>)` if the place expression is an instance attribute.
pub(crate) fn as_instance_attribute(&self) -> Option<&Name> {
if self.is_instance_attribute() {
debug_assert!(self.as_instance_attribute_candidate().is_some());
self.as_instance_attribute_candidate()
} else {
None
}
}
/// Is the place an instance attribute?
pub(crate) fn is_instance_attribute(&self) -> bool {
let is_instance_attribute = self.flags.contains(PlaceFlags::IS_INSTANCE_ATTRIBUTE);
@ -236,14 +307,12 @@ impl PlaceExpr {
is_instance_attribute
}
/// Return `Some(<ATTRIBUTE>)` if the place expression is an instance attribute.
pub(crate) fn as_instance_attribute(&self) -> Option<&Name> {
if self.is_instance_attribute() {
debug_assert!(self.as_instance_attribute_candidate().is_some());
self.as_instance_attribute_candidate()
} else {
None
}
pub(crate) fn is_name(&self) -> bool {
self.expr.is_name()
}
pub(crate) fn is_member(&self) -> bool {
self.expr.is_member()
}
/// Is the place used in its containing scope?
@ -261,56 +330,58 @@ impl PlaceExpr {
self.flags.contains(PlaceFlags::IS_DECLARED)
}
/// Is the place just a name?
pub fn is_name(&self) -> bool {
self.sub_segments.is_empty()
pub(crate) fn as_name(&self) -> Option<&Name> {
self.expr.as_name()
}
pub fn is_name_and(&self, f: impl FnOnce(&str) -> bool) -> bool {
self.is_name() && f(&self.root_name)
pub(crate) fn expect_name(&self) -> &Name {
self.expr.expect_name()
}
}
/// Does the place expression have the form `<object>.member`?
pub fn is_member(&self) -> bool {
self.sub_segments
.last()
.is_some_and(|last| last.as_member().is_some())
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub struct PlaceExprRef<'a> {
pub(crate) root_name: &'a Name,
pub(crate) sub_segments: &'a [PlaceExprSubSegment],
}
impl PartialEq<PlaceExpr> for PlaceExprRef<'_> {
fn eq(&self, other: &PlaceExpr) -> bool {
self.root_name == &other.root_name && self.sub_segments == &other.sub_segments[..]
}
}
pub(crate) fn segments(&self) -> PlaceSegments {
PlaceSegments {
root_name: Some(&self.root_name),
sub_segments: &self.sub_segments,
}
impl PartialEq<PlaceExprRef<'_>> for PlaceExpr {
fn eq(&self, other: &PlaceExprRef<'_>) -> bool {
&self.root_name == other.root_name && &self.sub_segments[..] == other.sub_segments
}
}
// TODO: Ideally this would iterate PlaceSegments instead of RootExprs, both to reduce
// allocation and to avoid having both flagged and non-flagged versions of PlaceExprs.
fn root_exprs(&self) -> RootExprs<'_> {
RootExprs {
expr: self,
len: self.sub_segments.len(),
impl<'e> From<&'e PlaceExpr> for PlaceExprRef<'e> {
fn from(expr: &'e PlaceExpr) -> Self {
PlaceExprRef {
root_name: &expr.root_name,
sub_segments: &expr.sub_segments,
}
}
}
struct RootExprs<'e> {
expr: &'e PlaceExpr,
expr_ref: PlaceExprRef<'e>,
len: usize,
}
impl Iterator for RootExprs<'_> {
type Item = PlaceExpr;
impl<'e> Iterator for RootExprs<'e> {
type Item = PlaceExprRef<'e>;
fn next(&mut self) -> Option<Self::Item> {
if self.len == 0 {
return None;
}
self.len -= 1;
Some(PlaceExpr {
root_name: self.expr.root_name.clone(),
sub_segments: self.expr.sub_segments[..self.len].iter().cloned().collect(),
flags: PlaceFlags::empty(),
Some(PlaceExprRef {
root_name: self.expr_ref.root_name,
sub_segments: &self.expr_ref.sub_segments[..self.len],
})
}
}
@ -333,41 +404,6 @@ bitflags! {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlaceSegment<'a> {
/// A first segment of a place expression (root name), e.g. `x` in `x.y.z[0]`.
Name(&'a ast::name::Name),
Member(&'a ast::name::Name),
IntSubscript(&'a ast::Int),
StringSubscript(&'a str),
}
#[derive(Debug, PartialEq, Eq)]
pub struct PlaceSegments<'a> {
root_name: Option<&'a ast::name::Name>,
sub_segments: &'a [PlaceExprSubSegment],
}
impl<'a> Iterator for PlaceSegments<'a> {
type Item = PlaceSegment<'a>;
fn next(&mut self) -> Option<Self::Item> {
if let Some(name) = self.root_name.take() {
return Some(PlaceSegment::Name(name));
}
if self.sub_segments.is_empty() {
return None;
}
let segment = &self.sub_segments[0];
self.sub_segments = &self.sub_segments[1..];
Some(match segment {
PlaceExprSubSegment::Member(name) => PlaceSegment::Member(name),
PlaceExprSubSegment::IntSubscript(int) => PlaceSegment::IntSubscript(int),
PlaceExprSubSegment::StringSubscript(string) => PlaceSegment::StringSubscript(string),
})
}
}
/// ID that uniquely identifies a place in a file.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FilePlaceId {
@ -575,7 +611,7 @@ impl ScopeKind {
#[derive(Default, salsa::Update)]
pub struct PlaceTable {
/// The place expressions in this scope.
places: IndexVec<ScopedPlaceId, PlaceExpr>,
places: IndexVec<ScopedPlaceId, PlaceExprWithFlags>,
/// The set of places.
place_set: PlaceSet,
@ -586,7 +622,7 @@ impl PlaceTable {
self.places.shrink_to_fit();
}
pub(crate) fn place_expr(&self, place_id: impl Into<ScopedPlaceId>) -> &PlaceExpr {
pub(crate) fn place_expr(&self, place_id: impl Into<ScopedPlaceId>) -> &PlaceExprWithFlags {
&self.places[place_id.into()]
}
@ -594,10 +630,10 @@ impl PlaceTable {
pub(crate) fn root_place_exprs(
&self,
place_expr: &PlaceExpr,
) -> impl Iterator<Item = &PlaceExpr> {
) -> impl Iterator<Item = &PlaceExprWithFlags> {
place_expr
.root_exprs()
.filter_map(|place_expr| self.place_by_expr(&place_expr))
.filter_map(|place_expr| self.place_by_expr(place_expr))
}
#[expect(unused)]
@ -605,11 +641,11 @@ impl PlaceTable {
self.places.indices()
}
pub fn places(&self) -> impl Iterator<Item = &PlaceExpr> {
pub fn places(&self) -> impl Iterator<Item = &PlaceExprWithFlags> {
self.places.iter()
}
pub fn symbols(&self) -> impl Iterator<Item = &PlaceExpr> {
pub fn symbols(&self) -> impl Iterator<Item = &PlaceExprWithFlags> {
self.places().filter(|place_expr| place_expr.is_name())
}
@ -620,19 +656,16 @@ impl PlaceTable {
/// Returns the place named `name`.
#[allow(unused)] // used in tests
pub(crate) fn place_by_name(&self, name: &str) -> Option<&PlaceExpr> {
pub(crate) fn place_by_name(&self, name: &str) -> Option<&PlaceExprWithFlags> {
let id = self.place_id_by_name(name)?;
Some(self.place_expr(id))
}
/// Returns the flagged place by the unflagged place expression.
///
/// TODO: Ideally this would take a [`PlaceSegments`] instead of [`PlaceExpr`], to avoid the
/// awkward distinction between "flagged" (canonical) and unflagged [`PlaceExpr`]; in that
/// world, we would only create [`PlaceExpr`] in semantic indexing; in type inference we'd
/// create [`PlaceSegments`] if we need to look up a [`PlaceExpr`]. The [`PlaceTable`] would
/// need to gain the ability to hash and look up by a [`PlaceSegments`].
pub(crate) fn place_by_expr(&self, place_expr: &PlaceExpr) -> Option<&PlaceExpr> {
/// Returns the flagged place.
pub(crate) fn place_by_expr<'e>(
&self,
place_expr: impl Into<PlaceExprRef<'e>>,
) -> Option<&PlaceExprWithFlags> {
let id = self.place_id_by_expr(place_expr)?;
Some(self.place_expr(id))
}
@ -650,12 +683,16 @@ impl PlaceTable {
}
/// Returns the [`ScopedPlaceId`] of the place expression.
pub(crate) fn place_id_by_expr(&self, place_expr: &PlaceExpr) -> Option<ScopedPlaceId> {
pub(crate) fn place_id_by_expr<'e>(
&self,
place_expr: impl Into<PlaceExprRef<'e>>,
) -> Option<ScopedPlaceId> {
let place_expr = place_expr.into();
let (id, ()) = self
.place_set
.raw_entry()
.from_hash(Self::hash_place_expr(place_expr), |id| {
self.place_expr(*id).segments() == place_expr.segments()
self.place_expr(*id).expr == place_expr
})?;
Some(*id)
@ -673,10 +710,12 @@ impl PlaceTable {
hasher.finish()
}
fn hash_place_expr(place_expr: &PlaceExpr) -> u64 {
fn hash_place_expr<'e>(place_expr: impl Into<PlaceExprRef<'e>>) -> u64 {
let place_expr = place_expr.into();
let mut hasher = FxHasher::default();
place_expr.root_name().as_str().hash(&mut hasher);
for segment in &place_expr.sub_segments {
place_expr.root_name.as_str().hash(&mut hasher);
for segment in place_expr.sub_segments {
match segment {
PlaceExprSubSegment::Member(name) => name.hash(&mut hasher),
PlaceExprSubSegment::IntSubscript(int) => int.hash(&mut hasher),
@ -725,11 +764,11 @@ impl PlaceTableBuilder {
match entry {
RawEntryMut::Occupied(entry) => (*entry.key(), false),
RawEntryMut::Vacant(entry) => {
let symbol = PlaceExpr::name(name);
let symbol = PlaceExprWithFlags::name(name);
let id = self.table.places.push(symbol);
entry.insert_with_hasher(hash, id, (), |id| {
PlaceTable::hash_place_expr(&self.table.places[*id])
PlaceTable::hash_place_expr(&self.table.places[*id].expr)
});
let new_id = self.associated_place_ids.push(vec![]);
debug_assert_eq!(new_id, id);
@ -738,23 +777,25 @@ impl PlaceTableBuilder {
}
}
pub(super) fn add_place(&mut self, place_expr: PlaceExpr) -> (ScopedPlaceId, bool) {
let hash = PlaceTable::hash_place_expr(&place_expr);
let entry = self.table.place_set.raw_entry_mut().from_hash(hash, |id| {
self.table.places[*id].segments() == place_expr.segments()
});
pub(super) fn add_place(&mut self, place_expr: PlaceExprWithFlags) -> (ScopedPlaceId, bool) {
let hash = PlaceTable::hash_place_expr(&place_expr.expr);
let entry = self
.table
.place_set
.raw_entry_mut()
.from_hash(hash, |id| self.table.places[*id].expr == place_expr.expr);
match entry {
RawEntryMut::Occupied(entry) => (*entry.key(), false),
RawEntryMut::Vacant(entry) => {
let id = self.table.places.push(place_expr);
entry.insert_with_hasher(hash, id, (), |id| {
PlaceTable::hash_place_expr(&self.table.places[*id])
PlaceTable::hash_place_expr(&self.table.places[*id].expr)
});
let new_id = self.associated_place_ids.push(vec![]);
debug_assert_eq!(new_id, id);
for root in self.table.places[id].root_exprs() {
if let Some(root_id) = self.table.place_id_by_expr(&root) {
for root in self.table.places[id].expr.root_exprs() {
if let Some(root_id) = self.table.place_id_by_expr(root) {
self.associated_place_ids[root_id].push(id);
}
}
@ -775,7 +816,7 @@ impl PlaceTableBuilder {
self.table.places[id].insert_flags(PlaceFlags::IS_USED);
}
pub(super) fn places(&self) -> impl Iterator<Item = &PlaceExpr> {
pub(super) fn places(&self) -> impl Iterator<Item = &PlaceExprWithFlags> {
self.table.places()
}
@ -783,7 +824,7 @@ impl PlaceTableBuilder {
self.table.place_id_by_expr(place_expr)
}
pub(super) fn place_expr(&self, place_id: impl Into<ScopedPlaceId>) -> &PlaceExpr {
pub(super) fn place_expr(&self, place_id: impl Into<ScopedPlaceId>) -> &PlaceExprWithFlags {
self.table.place_expr(place_id)
}

View file

@ -237,19 +237,21 @@ use self::place_state::{
LiveDeclarationsIterator, PlaceState, ScopedDefinitionId,
};
use crate::node_key::NodeKey;
use crate::semantic_index::EagerSnapshotResult;
use crate::semantic_index::ast_ids::ScopedUseId;
use crate::semantic_index::definition::{Definition, DefinitionState};
use crate::semantic_index::narrowing_constraints::{
ConstraintKey, NarrowingConstraints, NarrowingConstraintsBuilder, NarrowingConstraintsIterator,
};
use crate::semantic_index::place::{FileScopeId, PlaceExpr, ScopeKind, ScopedPlaceId};
use crate::semantic_index::place::{
FileScopeId, PlaceExpr, PlaceExprWithFlags, ScopeKind, ScopedPlaceId,
};
use crate::semantic_index::predicate::{
Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, StarImportPlaceholderPredicate,
};
use crate::semantic_index::reachability_constraints::{
ReachabilityConstraints, ReachabilityConstraintsBuilder, ScopedReachabilityConstraintId,
};
use crate::semantic_index::{EagerSnapshotResult, SemanticIndex};
use crate::types::{IntersectionBuilder, Truthiness, Type, infer_narrowing_constraint};
mod place_state;
@ -320,6 +322,11 @@ pub(crate) struct UseDefMap<'db> {
end_of_scope_reachability: ScopedReachabilityConstraintId,
}
pub(crate) enum ApplicableConstraints<'map, 'db> {
UnboundBinding(ConstraintsIterator<'map, 'db>),
ConstrainedBindings(BindingWithConstraintsIterator<'map, 'db>),
}
impl<'db> UseDefMap<'db> {
pub(crate) fn bindings_at_use(
&self,
@ -328,19 +335,33 @@ impl<'db> UseDefMap<'db> {
self.bindings_iterator(&self.bindings_by_use[use_id])
}
pub(crate) fn narrowing_constraints_at_use(
pub(crate) fn applicable_constraints(
&self,
constraint_key: ConstraintKey,
) -> ConstraintsIterator<'_, 'db> {
let constraint = match constraint_key {
ConstraintKey::NarrowingConstraint(constraint) => constraint,
ConstraintKey::UseId(use_id) => {
self.bindings_by_use[use_id].unbound_narrowing_constraint()
enclosing_scope: FileScopeId,
expr: &PlaceExpr,
index: &'db SemanticIndex,
) -> ApplicableConstraints<'_, 'db> {
match constraint_key {
ConstraintKey::NarrowingConstraint(constraint) => {
ApplicableConstraints::UnboundBinding(ConstraintsIterator {
predicates: &self.predicates,
constraint_ids: self.narrowing_constraints.iter_predicates(constraint),
})
}
ConstraintKey::EagerNestedScope(nested_scope) => {
let EagerSnapshotResult::FoundBindings(bindings) =
index.eager_snapshot(enclosing_scope, expr, nested_scope)
else {
unreachable!(
"The result of `SemanticIndex::eager_snapshot` must be `FoundBindings`"
)
};
ApplicableConstraints::ConstrainedBindings(bindings)
}
ConstraintKey::UseId(use_id) => {
ApplicableConstraints::ConstrainedBindings(self.bindings_at_use(use_id))
}
};
ConstraintsIterator {
predicates: &self.predicates,
constraint_ids: self.narrowing_constraints.iter_predicates(constraint),
}
}
@ -884,7 +905,7 @@ impl<'db> UseDefMapBuilder<'db> {
&mut self,
enclosing_place: ScopedPlaceId,
scope: ScopeKind,
enclosing_place_expr: &PlaceExpr,
enclosing_place_expr: &PlaceExprWithFlags,
) -> ScopedEagerSnapshotId {
// Names bound in class scopes are never visible to nested scopes (but attributes/subscripts are visible),
// so we never need to save eager scope bindings in a class scope.

View file

@ -60,7 +60,7 @@ use crate::semantic_index::ast_ids::{
};
use crate::semantic_index::definition::{
AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, ComprehensionDefinitionKind,
Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind,
Definition, DefinitionKind, DefinitionNodeKey, DefinitionState, ExceptHandlerDefinitionKind,
ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind,
};
use crate::semantic_index::expression::{Expression, ExpressionKind};
@ -68,7 +68,9 @@ use crate::semantic_index::narrowing_constraints::ConstraintKey;
use crate::semantic_index::place::{
FileScopeId, NodeWithScopeKind, NodeWithScopeRef, PlaceExpr, ScopeId, ScopeKind, ScopedPlaceId,
};
use crate::semantic_index::{EagerSnapshotResult, SemanticIndex, place_table, semantic_index};
use crate::semantic_index::{
ApplicableConstraints, EagerSnapshotResult, SemanticIndex, place_table, semantic_index,
};
use crate::types::call::{
Argument, Binding, Bindings, CallArgumentTypes, CallArguments, CallError,
};
@ -746,6 +748,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.expression_type(expr.scoped_expression_id(self.db(), self.scope()))
}
fn try_expression_type(&self, expr: &ast::Expr) -> Option<Type<'db>> {
self.types
.try_expression_type(expr.scoped_expression_id(self.db(), self.scope()))
}
/// Get the type of an expression from any scope in the same file.
///
/// If the expression is in the current scope, and we are inferring the entire scope, just look
@ -1510,13 +1517,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let global_use_def_map = self.index.use_def_map(FileScopeId::global());
let place_id = binding.place(self.db());
let expr = place_table.place_expr(place_id);
let place = place_table.place_expr(place_id);
let skip_non_global_scopes = self.skip_non_global_scopes(file_scope_id, place_id);
let declarations = if skip_non_global_scopes {
match self
.index
.place_table(FileScopeId::global())
.place_id_by_expr(expr)
.place_id_by_expr(&place.expr)
{
Some(id) => global_use_def_map.public_declarations(id),
// This case is a syntax error (load before global declaration) but ignore that here
@ -1527,18 +1534,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
};
let declared_ty = place_from_declarations(self.db(), declarations)
.and_then(|place| {
Ok(if matches!(place.place, Place::Type(_, Boundness::Bound)) {
place
} else if skip_non_global_scopes
|| self.scope().file_scope_id(self.db()).is_global()
{
let module_type_declarations =
module_type_implicit_global_declaration(self.db(), expr)?;
place.or_fall_back_to(self.db(), || module_type_declarations)
} else {
place
})
.and_then(|place_and_quals| {
Ok(
if matches!(place_and_quals.place, Place::Type(_, Boundness::Bound)) {
place_and_quals
} else if skip_non_global_scopes
|| self.scope().file_scope_id(self.db()).is_global()
{
let module_type_declarations =
module_type_implicit_global_declaration(self.db(), &place.expr)?;
place_and_quals.or_fall_back_to(self.db(), || module_type_declarations)
} else {
place_and_quals
},
)
})
.map(
|PlaceAndQualifiers {
@ -1576,10 +1585,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
)
.unwrap_or_else(|(ty, conflicting)| {
// TODO point out the conflicting declarations in the diagnostic?
let expr = place_table.place_expr(binding.place(db));
let place = place_table.place_expr(binding.place(db));
if let Some(builder) = self.context.report_lint(&CONFLICTING_DECLARATIONS, node) {
builder.into_diagnostic(format_args!(
"Conflicting declared types for `{expr}`: {}",
"Conflicting declared types for `{place}`: {}",
conflicting.display(db)
));
}
@ -1590,6 +1599,54 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// allow declarations to override inference in case of invalid assignment
bound_ty = declared_ty;
}
// In the following cases, the bound type may not be the same as the RHS value type.
if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node {
let value_ty = self
.try_expression_type(value)
.unwrap_or_else(|| self.infer_maybe_standalone_expression(value));
// If the member is a data descriptor, the RHS value may differ from the value actually assigned.
if value_ty
.class_member(db, attr.id.clone())
.place
.ignore_possibly_unbound()
.is_some_and(|ty| ty.may_be_data_descriptor(db))
{
bound_ty = declared_ty;
}
} else if let AnyNodeRef::ExprSubscript(ast::ExprSubscript { value, .. }) = node {
let value_ty = self
.try_expression_type(value)
.unwrap_or_else(|| self.infer_expression(value));
// Arbitrary `__getitem__`/`__setitem__` methods on a class do not
// necessarily guarantee that the passed-in value for `__setitem__` is stored and
// can be retrieved unmodified via `__getitem__`. Therefore, we currently only
// perform assignment-based narrowing on a few built-in classes (`list`, `dict`,
// `bytesarray`, `TypedDict` and `collections` types) where we are confident that
// this kind of narrowing can be performed soundly. This is the same approach as
// pyright. TODO: Other standard library classes may also be considered safe. Also,
// subclasses of these safe classes that do not override `__getitem__/__setitem__`
// may be considered safe.
let safe_mutable_classes = [
KnownClass::List.to_instance(db),
KnownClass::Dict.to_instance(db),
KnownClass::Bytearray.to_instance(db),
KnownClass::DefaultDict.to_instance(db),
SpecialFormType::ChainMap.instance_fallback(db),
SpecialFormType::Counter.instance_fallback(db),
SpecialFormType::Deque.instance_fallback(db),
SpecialFormType::OrderedDict.instance_fallback(db),
SpecialFormType::TypedDict.instance_fallback(db),
];
if safe_mutable_classes.iter().all(|safe_mutable_class| {
!value_ty.is_equivalent_to(db, *safe_mutable_class)
&& value_ty
.generic_origin(db)
.zip(safe_mutable_class.generic_origin(db))
.is_none_or(|(l, r)| l != r)
}) {
bound_ty = declared_ty;
}
}
self.types.bindings.insert(binding, bound_ty);
}
@ -1624,9 +1681,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Fallback to bindings declared on `types.ModuleType` if it's a global symbol
let scope = self.scope().file_scope_id(self.db());
let place_table = self.index.place_table(scope);
let expr = place_table.place_expr(declaration.place(self.db()));
if scope.is_global() && expr.is_name() {
module_type_implicit_global_symbol(self.db(), expr.expect_name())
let place = place_table.place_expr(declaration.place(self.db()));
if scope.is_global() && place.is_name() {
module_type_implicit_global_symbol(self.db(), place.expect_name())
} else {
Place::Unbound.into()
}
@ -1677,9 +1734,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let file_scope_id = self.scope().file_scope_id(self.db());
if file_scope_id.is_global() {
let place_table = self.index.place_table(file_scope_id);
let expr = place_table.place_expr(definition.place(self.db()));
let place = place_table.place_expr(definition.place(self.db()));
if let Some(module_type_implicit_declaration) =
module_type_implicit_global_declaration(self.db(), expr)
module_type_implicit_global_declaration(self.db(), &place.expr)
.ok()
.and_then(|place| place.place.ignore_possibly_unbound())
{
@ -1691,11 +1748,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.context.report_lint(&INVALID_DECLARATION, node)
{
let mut diagnostic = builder.into_diagnostic(format_args!(
"Cannot shadow implicit global attribute `{expr}` with declaration of type `{}`",
"Cannot shadow implicit global attribute `{place}` with declaration of type `{}`",
declared_type.display(self.db())
));
diagnostic.info(format_args!("The global symbol `{}` must always have a type assignable to `{}`",
expr,
place,
module_type_implicit_declaration.display(self.db())
));
}
@ -5920,7 +5977,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
// Perform narrowing with applicable constraints between the current scope and the enclosing scope.
fn narrow_with_applicable_constraints(
fn narrow_place_with_applicable_constraints(
&self,
expr: &PlaceExpr,
mut ty: Type<'db>,
@ -5929,11 +5986,69 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let db = self.db();
for (enclosing_scope_file_id, constraint_key) in constraint_keys {
let use_def = self.index.use_def_map(*enclosing_scope_file_id);
let constraints = use_def.narrowing_constraints_at_use(*constraint_key);
let place_table = self.index.place_table(*enclosing_scope_file_id);
let place = place_table.place_id_by_expr(expr).unwrap();
ty = constraints.narrow(db, ty, place);
match use_def.applicable_constraints(
*constraint_key,
*enclosing_scope_file_id,
expr,
self.index,
) {
ApplicableConstraints::UnboundBinding(constraint) => {
ty = constraint.narrow(db, ty, place);
}
// Performs narrowing based on constrained bindings.
// This handling must be performed even if narrowing is attempted and failed using `infer_place_load`.
// The result of `infer_place_load` can be applied as is only when its boundness is `Bound`.
// For example, this handling is required in the following case:
// ```python
// class C:
// x: int | None = None
// c = C()
// # c.x: int | None = <unbound>
// if c.x is None:
// c.x = 1
// # else: c.x: int = <unbound>
// # `c.x` is not definitely bound here
// reveal_type(c.x) # revealed: int
// ```
ApplicableConstraints::ConstrainedBindings(bindings) => {
let reachability_constraints = bindings.reachability_constraints;
let predicates = bindings.predicates;
let mut union = UnionBuilder::new(db);
for binding in bindings {
let static_reachability = reachability_constraints.evaluate(
db,
predicates,
binding.reachability_constraint,
);
if static_reachability.is_always_false() {
continue;
}
match binding.binding {
DefinitionState::Defined(definition) => {
let binding_ty = binding_type(db, definition);
union = union.add(
binding.narrowing_constraint.narrow(db, binding_ty, place),
);
}
DefinitionState::Undefined | DefinitionState::Deleted => {
union =
union.add(binding.narrowing_constraint.narrow(db, ty, place));
}
}
}
// If there are no visible bindings, the union becomes `Never`.
// Since an unbound binding is recorded even for an undefined place,
// this can only happen if the code is unreachable
// and therefore it is correct to set the result to `Never`.
let union = union.build();
if union.is_assignable_to(db, ty) {
ty = union;
}
}
}
}
ty
}
@ -5956,7 +6071,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// These are looked up as attributes on `types.ModuleType`.
.or_fall_back_to(db, || {
module_type_implicit_global_symbol(db, symbol_name).map_type(|ty| {
self.narrow_with_applicable_constraints(&expr, ty, &constraint_keys)
self.narrow_place_with_applicable_constraints(&expr, ty, &constraint_keys)
})
})
// Not found in globals? Fallback to builtins
@ -6028,7 +6143,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
/// Infer the type of a place expression, assuming a load context.
/// Infer the type of a place expression from definitions, assuming a load context.
/// This method also returns the [`ConstraintKey`]s for each scope associated with `expr`,
/// which is used to narrow by condition rather than by assignment.
fn infer_place_load(
&self,
expr: &PlaceExpr,
@ -6041,6 +6158,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let mut constraint_keys = vec![];
let (local_scope_place, use_id) = self.infer_local_place_load(expr, expr_ref);
if let Some(use_id) = use_id {
constraint_keys.push((file_scope_id, ConstraintKey::UseId(use_id)));
}
let place = PlaceAndQualifiers::from(local_scope_place).or_fall_back_to(db, || {
let has_bindings_in_this_scope = match place_table.place_by_expr(expr) {
@ -6081,7 +6201,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for root_expr in place_table.root_place_exprs(expr) {
let mut expr_ref = expr_ref;
for _ in 0..(expr.sub_segments().len() - root_expr.sub_segments().len()) {
for _ in 0..(expr.sub_segments().len() - root_expr.expr.sub_segments().len()) {
match expr_ref {
ast::ExprRef::Attribute(attribute) => {
expr_ref = ast::ExprRef::from(&attribute.value);
@ -6092,16 +6212,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
_ => unreachable!(),
}
}
let (parent_place, _use_id) = self.infer_local_place_load(root_expr, expr_ref);
let (parent_place, _use_id) =
self.infer_local_place_load(&root_expr.expr, expr_ref);
if let Place::Type(_, _) = parent_place {
return Place::Unbound.into();
}
}
if let Some(use_id) = use_id {
constraint_keys.push((file_scope_id, ConstraintKey::UseId(use_id)));
}
// Walk up parent scopes looking for a possible enclosing scope that may have a
// definition of this name visible to us (would be `LOAD_DEREF` at runtime.)
// Note that we skip the scope containing the use that we are resolving, since we
@ -6144,15 +6261,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
{
continue;
}
return place_from_bindings(db, bindings)
.map_type(|ty| {
self.narrow_with_applicable_constraints(
expr,
ty,
&constraint_keys,
)
})
.into();
let place = place_from_bindings(db, bindings).map_type(|ty| {
self.narrow_place_with_applicable_constraints(
expr,
ty,
&constraint_keys,
)
});
constraint_keys.push((
enclosing_scope_file_id,
ConstraintKey::EagerNestedScope(file_scope_id),
));
return place.into();
}
// There are no visible bindings / constraint here.
// Don't fall back to non-eager place resolution.
@ -6163,7 +6283,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
{
if enclosing_root_place.is_bound() {
if let Place::Type(_, _) =
place(db, enclosing_scope_id, enclosing_root_place).place
place(db, enclosing_scope_id, &enclosing_root_place.expr)
.place
{
return Place::Unbound.into();
}
@ -6190,7 +6311,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// isn't bound in that scope, we should get an unbound name, not continue
// falling back to other scopes / globals / builtins.
return place(db, enclosing_scope_id, expr).map_type(|ty| {
self.narrow_with_applicable_constraints(expr, ty, &constraint_keys)
self.narrow_place_with_applicable_constraints(expr, ty, &constraint_keys)
});
}
}
@ -6215,15 +6336,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
));
}
EagerSnapshotResult::FoundBindings(bindings) => {
return place_from_bindings(db, bindings)
.map_type(|ty| {
self.narrow_with_applicable_constraints(
expr,
ty,
&constraint_keys,
)
})
.into();
let place = place_from_bindings(db, bindings).map_type(|ty| {
self.narrow_place_with_applicable_constraints(
expr,
ty,
&constraint_keys,
)
});
constraint_keys.push((
FileScopeId::global(),
ConstraintKey::EagerNestedScope(file_scope_id),
));
return place.into();
}
// There are no visible bindings / constraint here.
EagerSnapshotResult::NotFound => {
@ -6238,7 +6362,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
};
explicit_global_symbol(db, self.file(), name).map_type(|ty| {
self.narrow_with_applicable_constraints(expr, ty, &constraint_keys)
self.narrow_place_with_applicable_constraints(expr, ty, &constraint_keys)
})
})
});
@ -6302,6 +6426,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
fn narrow_expr_with_applicable_constraints<'r>(
&self,
target: impl Into<ast::ExprRef<'r>>,
target_ty: Type<'db>,
constraint_keys: &[(FileScopeId, ConstraintKey)],
) -> Type<'db> {
let target = target.into();
if let Ok(place_expr) = PlaceExpr::try_from(target) {
self.narrow_place_with_applicable_constraints(&place_expr, target_ty, constraint_keys)
} else {
target_ty
}
}
/// Infer the type of a [`ast::ExprAttribute`] expression, assuming a load context.
fn infer_attribute_load(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> {
let ast::ExprAttribute {
@ -6314,27 +6453,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let value_type = self.infer_maybe_standalone_expression(value);
let db = self.db();
let mut constraint_keys = vec![];
// If `attribute` is a valid reference, we attempt type narrowing by assignment.
let mut assigned_type = None;
if let Ok(place_expr) = PlaceExpr::try_from(attribute) {
let member = value_type.class_member(db, attr.id.clone());
// If the member is a data descriptor, the value most recently assigned
// to the attribute may not necessarily be obtained here.
if member
.place
.ignore_possibly_unbound()
.is_none_or(|ty| !ty.may_be_data_descriptor(db))
{
let (resolved, _) =
self.infer_place_load(&place_expr, ast::ExprRef::Attribute(attribute));
if let Place::Type(ty, Boundness::Bound) = resolved.place {
return ty;
}
let (resolved, keys) =
self.infer_place_load(&place_expr, ast::ExprRef::Attribute(attribute));
constraint_keys.extend(keys);
if let Place::Type(ty, Boundness::Bound) = resolved.place {
assigned_type = Some(ty);
}
}
value_type
let resolved_type = value_type
.member(db, &attr.id)
.map_type(|ty| self.narrow_expr_with_applicable_constraints(attribute, ty, &constraint_keys))
.unwrap_with_diagnostic(|lookup_error| match lookup_error {
LookupError::Unbound(_) => {
let report_unresolved_attribute = self.is_reachable(attribute);
@ -6394,7 +6527,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
type_when_bound
}
}).inner_type()
})
.inner_type();
// Even if we can obtain the attribute type based on the assignments, we still perform default type inference
// (to report errors).
assigned_type.unwrap_or(resolved_type)
}
fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> {
@ -7839,46 +7976,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
slice,
ctx: _,
} = subscript;
let db = self.db();
let value_ty = self.infer_expression(value);
let mut constraint_keys = vec![];
// If `value` is a valid reference, we attempt type narrowing by assignment.
if !value_ty.is_unknown() {
if let Ok(expr) = PlaceExpr::try_from(subscript) {
// Type narrowing based on assignment to a subscript expression is generally
// unsound, because arbitrary `__getitem__`/`__setitem__` methods on a class do not
// necessarily guarantee that the passed-in value for `__setitem__` is stored and
// can be retrieved unmodified via `__getitem__`. Therefore, we currently only
// perform assignment-based narrowing on a few built-in classes (`list`, `dict`,
// `bytesarray`, `TypedDict` and `collections` types) where we are confident that
// this kind of narrowing can be performed soundly. This is the same approach as
// pyright. TODO: Other standard library classes may also be considered safe. Also,
// subclasses of these safe classes that do not override `__getitem__/__setitem__`
// may be considered safe.
let safe_mutable_classes = [
KnownClass::List.to_instance(db),
KnownClass::Dict.to_instance(db),
KnownClass::Bytearray.to_instance(db),
KnownClass::DefaultDict.to_instance(db),
SpecialFormType::ChainMap.instance_fallback(db),
SpecialFormType::Counter.instance_fallback(db),
SpecialFormType::Deque.instance_fallback(db),
SpecialFormType::OrderedDict.instance_fallback(db),
SpecialFormType::TypedDict.instance_fallback(db),
];
if safe_mutable_classes.iter().any(|safe_mutable_class| {
value_ty.is_equivalent_to(db, *safe_mutable_class)
|| value_ty
.generic_origin(db)
.zip(safe_mutable_class.generic_origin(db))
.is_some_and(|(l, r)| l == r)
}) {
let (place, _) =
self.infer_place_load(&expr, ast::ExprRef::Subscript(subscript));
if let Place::Type(ty, Boundness::Bound) = place.place {
self.infer_expression(slice);
return ty;
}
let (place, keys) =
self.infer_place_load(&expr, ast::ExprRef::Subscript(subscript));
constraint_keys.extend(keys);
if let Place::Type(ty, Boundness::Bound) = place.place {
// Even if we can obtain the subscript type based on the assignments, we still perform default type inference
// (to store the expression type and to report errors).
let slice_ty = self.infer_expression(slice);
self.infer_subscript_expression_types(value, value_ty, slice_ty);
return ty;
}
}
}
@ -7908,7 +8020,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
let slice_ty = self.infer_expression(slice);
self.infer_subscript_expression_types(value, value_ty, slice_ty)
let result_ty = self.infer_subscript_expression_types(value, value_ty, slice_ty);
self.narrow_expr_with_applicable_constraints(subscript, result_ty, &constraint_keys)
}
fn infer_explicit_class_specialization(

View file

@ -1,7 +1,7 @@
use crate::Db;
use crate::semantic_index::ast_ids::HasScopedExpressionId;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::place::{PlaceTable, ScopeId, ScopedPlaceId};
use crate::semantic_index::place::{PlaceExpr, PlaceTable, ScopeId, ScopedPlaceId};
use crate::semantic_index::place_table;
use crate::semantic_index::predicate::{
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
@ -247,13 +247,12 @@ fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db,
}
}
fn expr_name(expr: &ast::Expr) -> Option<&ast::name::Name> {
fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> {
match expr {
ast::Expr::Named(ast::ExprNamed { target, .. }) => match target.as_ref() {
ast::Expr::Name(ast::ExprName { id, .. }) => Some(id),
_ => None,
},
ast::Expr::Name(ast::ExprName { id, .. }) => Some(id),
ast::Expr::Name(name) => Some(PlaceExpr::name(name.id.clone())),
ast::Expr::Attribute(attr) => PlaceExpr::try_from(attr).ok(),
ast::Expr::Subscript(subscript) => PlaceExpr::try_from(subscript).ok(),
ast::Expr::Named(named) => PlaceExpr::try_from(named.target.as_ref()).ok(),
_ => None,
}
}
@ -314,7 +313,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
match expression_node {
ast::Expr::Name(name) => Some(self.evaluate_expr_name(name, is_positive)),
ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => {
self.evaluate_simple_expr(expression_node, is_positive)
}
ast::Expr::Compare(expr_compare) => {
self.evaluate_expr_compare(expr_compare, expression, is_positive)
}
@ -374,27 +375,27 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
#[track_caller]
fn expect_expr_name_symbol(&self, symbol: &str) -> ScopedPlaceId {
fn expect_place(&self, place_expr: &PlaceExpr) -> ScopedPlaceId {
self.places()
.place_id_by_name(symbol)
.expect("We should always have a symbol for every `Name` node")
.place_id_by_expr(place_expr)
.expect("We should always have a place for every `PlaceExpr`")
}
fn evaluate_expr_name(
fn evaluate_simple_expr(
&mut self,
expr_name: &ast::ExprName,
expr: &ast::Expr,
is_positive: bool,
) -> NarrowingConstraints<'db> {
let ast::ExprName { id, .. } = expr_name;
) -> Option<NarrowingConstraints<'db>> {
let target = place_expr(expr)?;
let place = self.expect_place(&target);
let symbol = self.expect_expr_name_symbol(id);
let ty = if is_positive {
Type::AlwaysFalsy.negate(self.db)
} else {
Type::AlwaysTruthy.negate(self.db)
};
NarrowingConstraints::from_iter([(symbol, ty)])
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
fn evaluate_expr_named(
@ -402,11 +403,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expr_named: &ast::ExprNamed,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
if let ast::Expr::Name(expr_name) = expr_named.target.as_ref() {
Some(self.evaluate_expr_name(expr_name, is_positive))
} else {
None
}
self.evaluate_simple_expr(&expr_named.target, is_positive)
}
fn evaluate_expr_eq(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
@ -598,7 +595,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool {
matches!(
expr,
ast::Expr::Name(_) | ast::Expr::Call(_) | ast::Expr::Named(_)
ast::Expr::Name(_)
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Call(_)
| ast::Expr::Named(_)
)
}
@ -644,13 +645,16 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
last_rhs_ty = Some(rhs_ty);
match left {
ast::Expr::Name(_) | ast::Expr::Named(_) => {
if let Some(id) = expr_name(left) {
let symbol = self.expect_expr_name_symbol(id);
ast::Expr::Name(_)
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Named(_) => {
if let Some(left) = place_expr(left) {
let op = if is_positive { *op } else { op.negate() };
if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
constraints.insert(symbol, ty);
let place = self.expect_place(&left);
constraints.insert(place, ty);
}
}
}
@ -674,9 +678,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
};
let id = match &**args {
[first] => match expr_name(first) {
Some(id) => id,
let target = match &**args {
[first] => match place_expr(first) {
Some(target) => target,
None => continue,
},
_ => continue,
@ -699,9 +703,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.into_class_literal()
.is_some_and(|c| c.is_known(self.db, KnownClass::Type))
{
let symbol = self.expect_expr_name_symbol(id);
let place = self.expect_place(&target);
constraints.insert(
symbol,
place,
Type::instance(self.db, rhs_class.unknown_specialization(self.db)),
);
}
@ -754,9 +758,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let [first_arg, second_arg] = &*expr_call.arguments.args else {
return None;
};
let first_arg = expr_name(first_arg)?;
let first_arg = place_expr(first_arg)?;
let function = function_type.known(self.db)?;
let symbol = self.expect_expr_name_symbol(first_arg);
let place = self.expect_place(&first_arg);
if function == KnownFunction::HasAttr {
let attr = inference
@ -774,7 +778,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
);
return Some(NarrowingConstraints::from_iter([(
symbol,
place,
constraint.negate_if(self.db, !is_positive),
)]));
}
@ -788,7 +792,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.generate_constraint(self.db, class_info_ty)
.map(|constraint| {
NarrowingConstraints::from_iter([(
symbol,
place,
constraint.negate_if(self.db, !is_positive),
)])
})
@ -814,15 +818,15 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
subject: Expression<'db>,
singleton: ast::Singleton,
) -> Option<NarrowingConstraints<'db>> {
let symbol = self
.expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id);
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
let ty = match singleton {
ast::Singleton::None => Type::none(self.db),
ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false),
};
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
fn evaluate_match_pattern_class(
@ -830,11 +834,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
subject: Expression<'db>,
cls: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
let symbol = self
.expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id);
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
let ty = infer_same_file_expression_type(self.db, cls, self.module).to_instance(self.db)?;
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
fn evaluate_match_pattern_value(
@ -842,10 +847,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
subject: Expression<'db>,
value: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
let symbol = self
.expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id);
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
let ty = infer_same_file_expression_type(self.db, value, self.module);
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
fn evaluate_match_pattern_or(