From b01003f81d25c6296ecd3f2a18b13521110bb8d5 Mon Sep 17 00:00:00 2001 From: David Peter Date: Thu, 26 Jun 2025 12:24:40 +0200 Subject: [PATCH] [ty] Infer nonlocal types as unions of all reachable bindings (#18750) ## Summary This PR includes a behavioral change to how we infer types for public uses of symbols within a module. Where we would previously use the type that a use at the end of the scope would see, we now consider all reachable bindings and union the results: ```py x = None def f(): reveal_type(x) # previously `Unknown | Literal[1]`, now `Unknown | None | Literal[1]` f() x = 1 f() ``` This helps especially in cases where the the end of the scope is not reachable: ```py def outer(x: int): def inner(): reveal_type(x) # previously `Unknown`, now `int` raise ValueError ``` This PR also proposes to skip the boundness analysis of public uses. This is consistent with the "all reachable bindings" strategy, because the implicit `x = ` binding is also always reachable, and we would have to emit "possibly-unresolved" diagnostics for every public use otherwise. Changing this behavior allows common use-cases like the following to type check without any errors: ```py def outer(flag: bool): if flag: x = 1 def inner(): print(x) # previously: possibly-unresolved-reference, now: no error ``` closes https://github.com/astral-sh/ty/issues/210 closes https://github.com/astral-sh/ty/issues/607 closes https://github.com/astral-sh/ty/issues/699 ## Follow up It is now possible to resolve the following TODO, but I would like to do that as a follow-up, because it requires some changes to how we treat implicit attribute assignments, which could result in ecosystem changes that I'd like to see separately. https://github.com/astral-sh/ruff/blob/315fb0f3da4e5f2097a294c344025dec28ebf2a5/crates/ty_python_semantic/src/semantic_index/builder.rs#L1095-L1117 ## Ecosystem analysis [**Full report**](https://shark.fish/diff-public-types.html) * This change obviously removes a lot of `possibly-unresolved-reference` diagnostics (7818) because we do not analyze boundness for public uses of symbols inside modules anymore. * As the primary goal here, this change also removes a lot of false-positive `unresolved-reference` diagnostics (231) in scenarios like this: ```py def _(flag: bool): if flag: x = 1 def inner(): x raise ``` * This change also introduces some new false positives for cases like: ```py def _(): x = None x = "test" def inner(): x.upper() # Attribute `upper` on type `Unknown | None | Literal["test"]` is possibly unbound ``` We have test cases for these situations and it's plausible that we can improve this in a follow-up. ## Test Plan New Markdown tests --- .../resources/mdtest/declaration/error.md | 18 + .../resources/mdtest/public_types.md | 423 ++++++++++++++++++ .../resources/mdtest/scopes/builtin.md | 6 +- .../resources/mdtest/scopes/eager.md | 15 +- .../mdtest/statically_known_branches.md | 35 +- .../resources/mdtest/terminal_statements.md | 20 +- .../resources/mdtest/unreachable.md | 8 +- crates/ty_python_semantic/src/place.rs | 377 +++++++++++++--- .../ty_python_semantic/src/semantic_index.rs | 4 +- .../src/semantic_index/use_def.rs | 121 ++++- .../src/semantic_index/use_def/place_state.rs | 38 +- crates/ty_python_semantic/src/types.rs | 1 - crates/ty_python_semantic/src/types/class.rs | 10 +- .../src/types/ide_support.rs | 6 +- crates/ty_python_semantic/src/types/infer.rs | 44 +- .../src/types/protocol_class.rs | 20 +- .../src/types/signatures.rs | 8 +- 17 files changed, 983 insertions(+), 171 deletions(-) create mode 100644 crates/ty_python_semantic/resources/mdtest/public_types.md diff --git a/crates/ty_python_semantic/resources/mdtest/declaration/error.md b/crates/ty_python_semantic/resources/mdtest/declaration/error.md index 3f435d8e5d..424c04b0a9 100644 --- a/crates/ty_python_semantic/resources/mdtest/declaration/error.md +++ b/crates/ty_python_semantic/resources/mdtest/declaration/error.md @@ -32,6 +32,24 @@ def _(flag1: bool, flag2: bool): x = 1 # error: [conflicting-declarations] "Conflicting declared types for `x`: str, int" ``` +## Incompatible declarations with repeated types + +```py +def _(flag1: bool, flag2: bool, flag3: bool, flag4: bool): + if flag1: + x: str + elif flag2: + x: int + elif flag3: + x: int + elif flag4: + x: str + else: + x: bytes + + x = "a" # error: [conflicting-declarations] "Conflicting declared types for `x`: str, int, bytes" +``` + ## Incompatible declarations with bad assignment ```py diff --git a/crates/ty_python_semantic/resources/mdtest/public_types.md b/crates/ty_python_semantic/resources/mdtest/public_types.md new file mode 100644 index 0000000000..15536f1d04 --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/public_types.md @@ -0,0 +1,423 @@ +# Public types + +## Basic + +The "public type" of a symbol refers to the type that is inferred in a nested scope for a symbol +defined in an outer enclosing scope. Since it is not generally possible to analyze the full control +flow of a program, we currently make the simplifying assumption that an inner scope (such as the +`inner` function below) could be executed at any position in the enclosing scope. The public type +should therefore be the union of all possible types that the symbol could have. + +In the following example, depending on when `inner()` is called, the type of `x` could either be `A` +or `B`: + +```py +class A: ... +class B: ... +class C: ... + +def outer() -> None: + x = A() + + def inner() -> None: + # TODO: We might ideally be able to eliminate `Unknown` from the union here since `x` resolves to an + # outer scope that is a function scope (as opposed to module global scope), and `x` is never declared + # nonlocal in a nested scope that also assigns to it. + reveal_type(x) # revealed: Unknown | A | B + # This call would observe `x` as `A`. + inner() + + x = B() + + # This call would observe `x` as `B`. + inner() +``` + +Similarly, if control flow in the outer scope can split, the public type of `x` should reflect that: + +```py +def outer(flag: bool) -> None: + x = A() + + def inner() -> None: + reveal_type(x) # revealed: Unknown | A | B | C + inner() + + if flag: + x = B() + + inner() + else: + x = C() + + inner() + + inner() +``` + +If a binding is not reachable, it is not considered in the public type: + +```py +def outer() -> None: + x = A() + + def inner() -> None: + reveal_type(x) # revealed: Unknown | A | C + inner() + + if False: + x = B() # this binding of `x` is unreachable + inner() + + x = C() + inner() + +def outer(flag: bool) -> None: + x = A() + + def inner() -> None: + reveal_type(x) # revealed: Unknown | A | C + inner() + + if flag: + return + + x = B() # this binding of `x` is unreachable + + x = C() + inner() +``` + +If a symbol is only conditionally bound, we do not raise any errors: + +```py +def outer(flag: bool) -> None: + if flag: + x = A() + + def inner() -> None: + reveal_type(x) # revealed: Unknown | A + inner() +``` + +In the future, we may try to be smarter about which bindings must or must not be a visible to a +given nested scope, depending where it is defined. In the above case, this shouldn't change the +behavior -- `x` is defined before `inner` in the same branch, so should be considered +definitely-bound for `inner`. But in other cases we may want to emit `possibly-unresolved-reference` +in future: + +```py +def outer(flag: bool) -> None: + if flag: + x = A() + + def inner() -> None: + # TODO: Ideally, we would emit a possibly-unresolved-reference error here. + reveal_type(x) # revealed: Unknown | A + inner() +``` + +The public type is available, even if the end of the outer scope is unreachable. This is a +regression test. A previous version of ty used the end-of-scope position to determine the public +type, which would have resulted in incorrect type inference here: + +```py +def outer() -> None: + x = A() + + def inner() -> None: + reveal_type(x) # revealed: Unknown | A + inner() + + return + # unreachable + +def outer(flag: bool) -> None: + x = A() + + def inner() -> None: + reveal_type(x) # revealed: Unknown | A | B + if flag: + x = B() + inner() + return + # unreachable + + inner() + +def outer(x: A) -> None: + def inner() -> None: + reveal_type(x) # revealed: A + raise +``` + +An arbitrary level of nesting is supported: + +```py +def f0() -> None: + x = A() + + def f1() -> None: + def f2() -> None: + def f3() -> None: + def f4() -> None: + reveal_type(x) # revealed: Unknown | A | B + f4() + f3() + f2() + f1() + + x = B() + + f1() +``` + +## At module level + +The behavior is the same if the outer scope is the global scope of a module: + +```py +def flag() -> bool: + return True + +if flag(): + x = 1 + + def f() -> None: + reveal_type(x) # revealed: Unknown | Literal[1, 2] + # Function only used inside this branch + f() + + x = 2 + + # Function only used inside this branch + f() +``` + +## Mixed declarations and bindings + +When a declaration only appears in one branch, we also consider the types of the symbol's bindings +in other branches: + +```py +def flag() -> bool: + return True + +if flag(): + A: str = "" +else: + A = None + +reveal_type(A) # revealed: Literal[""] | None + +def _(): + reveal_type(A) # revealed: str | None +``` + +This pattern appears frequently with conditional imports. The `import` statement is both a +declaration and a binding, but we still add `None` to the public type union in a situation like +this: + +```py +try: + import optional_dependency # ty: ignore +except ImportError: + optional_dependency = None + +reveal_type(optional_dependency) # revealed: Unknown | None + +def _(): + reveal_type(optional_dependency) # revealed: Unknown | None +``` + +## Limitations + +### Type narrowing + +We currently do not further analyze control flow, so we do not support cases where the inner scope +is only executed in a branch where the type of `x` is narrowed: + +```py +class A: ... + +def outer(x: A | None): + if x is not None: + def inner() -> None: + # TODO: should ideally be `A` + reveal_type(x) # revealed: A | None + inner() +``` + +### Shadowing + +Similarly, since we do not analyze control flow in the outer scope here, we assume that `inner()` +could be called between the two assignments to `x`: + +```py +def outer() -> None: + def inner() -> None: + # TODO: this should ideally be `Unknown | Literal[1]`, but no other type checker supports this either + reveal_type(x) # revealed: Unknown | None | Literal[1] + x = None + + # [additional code here] + + x = 1 + + inner() +``` + +This is currently even true if the `inner` function is only defined after the second assignment to +`x`: + +```py +def outer() -> None: + x = None + + # [additional code here] + + x = 1 + + def inner() -> None: + # TODO: this should be `Unknown | Literal[1]`. Mypy and pyright support this. + reveal_type(x) # revealed: Unknown | None | Literal[1] + inner() +``` + +A similar case derived from an ecosystem example, involving declared types: + +```py +class C: ... + +def outer(x: C | None): + x = x or C() + + reveal_type(x) # revealed: C + + def inner() -> None: + # TODO: this should ideally be `C` + reveal_type(x) # revealed: C | None + inner() +``` + +### Assignments to nonlocal variables + +Writes to the outer-scope variable are currently not detected: + +```py +def outer() -> None: + x = None + + def set_x() -> None: + nonlocal x + x = 1 + set_x() + + def inner() -> None: + # TODO: this should ideally be `Unknown | None | Literal[1]`. Mypy and pyright support this. + reveal_type(x) # revealed: Unknown | None + inner() +``` + +## Handling of overloads + +### With implementation + +Overloads need special treatment, because here, we do not want to consider *all* possible +definitions of `f`. This would otherwise result in a union of all three definitions of `f`: + +```py +from typing import overload + +@overload +def f(x: int) -> int: ... +@overload +def f(x: str) -> str: ... +def f(x: int | str) -> int | str: + raise NotImplementedError + +reveal_type(f) # revealed: Overload[(x: int) -> int, (x: str) -> str] + +def _(): + reveal_type(f) # revealed: Overload[(x: int) -> int, (x: str) -> str] +``` + +This also works if there are conflicting declarations: + +```py +def flag() -> bool: + return True + +if flag(): + @overload + def g(x: int) -> int: ... + @overload + def g(x: str) -> str: ... + def g(x: int | str) -> int | str: + return x + +else: + g: str = "" + +def _(): + reveal_type(g) # revealed: (Overload[(x: int) -> int, (x: str) -> str]) | str + +# error: [conflicting-declarations] +g = "test" +``` + +### Without an implementation + +Similarly, if there is no implementation, we only consider the last overload definition. + +```pyi +from typing import overload + +@overload +def f(x: int) -> int: ... +@overload +def f(x: str) -> str: ... + +reveal_type(f) # revealed: Overload[(x: int) -> int, (x: str) -> str] + +def _(): + reveal_type(f) # revealed: Overload[(x: int) -> int, (x: str) -> str] +``` + +This also works if there are conflicting declarations: + +```pyi +def flag() -> bool: + return True + +if flag(): + @overload + def g(x: int) -> int: ... + @overload + def g(x: str) -> str: ... +else: + g: str + +def _(): + reveal_type(g) # revealed: (Overload[(x: int) -> int, (x: str) -> str]) | str +``` + +### Overload only defined in one branch + +```py +from typing import overload + +def flag() -> bool: + return True + +if flag(): + @overload + def f(x: int) -> int: ... + @overload + def f(x: str) -> str: ... + def f(x: int | str) -> int | str: + raise NotImplementedError + + def _(): + reveal_type(f) # revealed: Overload[(x: int) -> int, (x: str) -> str] +``` diff --git a/crates/ty_python_semantic/resources/mdtest/scopes/builtin.md b/crates/ty_python_semantic/resources/mdtest/scopes/builtin.md index 54fb8880a5..0df01f7883 100644 --- a/crates/ty_python_semantic/resources/mdtest/scopes/builtin.md +++ b/crates/ty_python_semantic/resources/mdtest/scopes/builtin.md @@ -29,6 +29,8 @@ if flag(): chr: int = 1 def _(): - reveal_type(abs) # revealed: Unknown | Literal[1] | (def abs(x: SupportsAbs[_T], /) -> _T) - reveal_type(chr) # revealed: int | (def chr(i: SupportsIndex, /) -> str) + # TODO: Should ideally be `Unknown | Literal[1] | (def abs(x: SupportsAbs[_T], /) -> _T)` + reveal_type(abs) # revealed: Unknown | Literal[1] + # TODO: Should ideally be `int | (def chr(i: SupportsIndex, /) -> str)` + reveal_type(chr) # revealed: int ``` diff --git a/crates/ty_python_semantic/resources/mdtest/scopes/eager.md b/crates/ty_python_semantic/resources/mdtest/scopes/eager.md index bfd11b6677..9677c8f70d 100644 --- a/crates/ty_python_semantic/resources/mdtest/scopes/eager.md +++ b/crates/ty_python_semantic/resources/mdtest/scopes/eager.md @@ -12,7 +12,7 @@ Function definitions are evaluated lazily. x = 1 def f(): - reveal_type(x) # revealed: Unknown | Literal[2] + reveal_type(x) # revealed: Unknown | Literal[1, 2] x = 2 ``` @@ -299,7 +299,7 @@ def _(): x = 1 def f(): - # revealed: Unknown | Literal[2] + # revealed: Unknown | Literal[1, 2] [reveal_type(x) for a in range(1)] x = 2 ``` @@ -316,7 +316,7 @@ def _(): class A: def f(): - # revealed: Unknown | Literal[2] + # revealed: Unknown | Literal[1, 2] reveal_type(x) x = 2 @@ -333,7 +333,7 @@ def _(): def f(): def g(): - # revealed: Unknown | Literal[2] + # revealed: Unknown | Literal[1, 2] reveal_type(x) x = 2 ``` @@ -351,7 +351,7 @@ def _(): class A: def f(): - # revealed: Unknown | Literal[2] + # revealed: Unknown | Literal[1, 2] [reveal_type(x) for a in range(1)] x = 2 @@ -389,7 +389,7 @@ x = int class C: var: ClassVar[x] -reveal_type(C.var) # revealed: Unknown | str +reveal_type(C.var) # revealed: Unknown | int | str x = str ``` @@ -404,7 +404,8 @@ x = int class C: var: ClassVar[x] -reveal_type(C.var) # revealed: str +# TODO: should ideally be `str`, but we currently consider all reachable bindings +reveal_type(C.var) # revealed: int | str x = str ``` diff --git a/crates/ty_python_semantic/resources/mdtest/statically_known_branches.md b/crates/ty_python_semantic/resources/mdtest/statically_known_branches.md index 1417714ac2..f6d65447cf 100644 --- a/crates/ty_python_semantic/resources/mdtest/statically_known_branches.md +++ b/crates/ty_python_semantic/resources/mdtest/statically_known_branches.md @@ -1242,18 +1242,27 @@ def f() -> None: #### `if True` +`mod.py`: + ```py x: str if True: x: int +``` -def f() -> None: - reveal_type(x) # revealed: int +`main.py`: + +```py +from mod import x + +reveal_type(x) # revealed: int ``` #### `if False … else` +`mod.py`: + ```py x: str @@ -1261,13 +1270,20 @@ if False: pass else: x: int +``` -def f() -> None: - reveal_type(x) # revealed: int +`main.py`: + +```py +from mod import x + +reveal_type(x) # revealed: int ``` ### Ambiguous +`mod.py`: + ```py def flag() -> bool: return True @@ -1276,9 +1292,14 @@ x: str if flag(): x: int +``` -def f() -> None: - reveal_type(x) # revealed: str | int +`main.py`: + +```py +from mod import x + +reveal_type(x) # revealed: str | int ``` ## Conditional function definitions @@ -1478,6 +1499,8 @@ if False: ```py # error: [unresolved-import] from module import symbol + +reveal_type(symbol) # revealed: Unknown ``` #### Always true, bound diff --git a/crates/ty_python_semantic/resources/mdtest/terminal_statements.md b/crates/ty_python_semantic/resources/mdtest/terminal_statements.md index 016ab3d655..1a9f168438 100644 --- a/crates/ty_python_semantic/resources/mdtest/terminal_statements.md +++ b/crates/ty_python_semantic/resources/mdtest/terminal_statements.md @@ -575,20 +575,18 @@ def f(): Free references inside of a function body refer to variables defined in the containing scope. Function bodies are _lazy scopes_: at runtime, these references are not resolved immediately at the point of the function definition. Instead, they are resolved _at the time of the call_, which means -that their values (and types) can be different for different invocations. For simplicity, we instead -resolve free references _at the end of the containing scope_. That means that in the examples below, -all of the `x` bindings should be visible to the `reveal_type`, regardless of where we place the -`return` statements. - -TODO: These currently produce the wrong results, but not because of our terminal statement support. -See [ruff#15777](https://github.com/astral-sh/ruff/issues/15777) for more details. +that their values (and types) can be different for different invocations. For simplicity, we +currently consider _all reachable bindings_ in the containing scope: ```py def top_level_return(cond1: bool, cond2: bool): x = 1 def g(): - # TODO eliminate Unknown + # TODO We could potentially eliminate `Unknown` from the union here, + # because `x` resolves to an enclosing function-like scope and there + # are no nested `nonlocal` declarations of that symbol that might + # modify it. reveal_type(x) # revealed: Unknown | Literal[1, 2, 3] if cond1: if cond2: @@ -601,8 +599,7 @@ def return_from_if(cond1: bool, cond2: bool): x = 1 def g(): - # TODO: Literal[1, 2, 3] - reveal_type(x) # revealed: Unknown | Literal[1] + reveal_type(x) # revealed: Unknown | Literal[1, 2, 3] if cond1: if cond2: x = 2 @@ -614,8 +611,7 @@ def return_from_nested_if(cond1: bool, cond2: bool): x = 1 def g(): - # TODO: Literal[1, 2, 3] - reveal_type(x) # revealed: Unknown | Literal[1, 3] + reveal_type(x) # revealed: Unknown | Literal[1, 2, 3] if cond1: if cond2: x = 2 diff --git a/crates/ty_python_semantic/resources/mdtest/unreachable.md b/crates/ty_python_semantic/resources/mdtest/unreachable.md index d5cf7b53fc..2d01b4dbbf 100644 --- a/crates/ty_python_semantic/resources/mdtest/unreachable.md +++ b/crates/ty_python_semantic/resources/mdtest/unreachable.md @@ -241,16 +241,16 @@ def f(): ### Use of variable in nested function -In the example below, since we use `x` in the `inner` function, we use the "public" type of `x`, -which currently refers to the end-of-scope type of `x`. Since the end of the `outer` scope is -unreachable, we need to make sure that we do not emit an `unresolved-reference` diagnostic: +This is a regression test for a behavior that previously caused problems when the public type still +referred to the end-of-scope, which would result in an unresolved-reference error here since the end +of the scope is unreachable. ```py def outer(): x = 1 def inner(): - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: Unknown | Literal[1] while True: pass ``` diff --git a/crates/ty_python_semantic/src/place.rs b/crates/ty_python_semantic/src/place.rs index 33d89ccd77..8d2abab1b6 100644 --- a/crates/ty_python_semantic/src/place.rs +++ b/crates/ty_python_semantic/src/place.rs @@ -12,7 +12,7 @@ use crate::types::{ KnownClass, Truthiness, Type, TypeAndQualifiers, TypeQualifiers, UnionBuilder, UnionType, binding_type, declaration_type, todo_type, }; -use crate::{Db, KnownModule, Program, resolve_module}; +use crate::{Db, FxOrderSet, KnownModule, Program, resolve_module}; pub(crate) use implicit_globals::{ module_type_implicit_global_declaration, module_type_implicit_global_symbol, @@ -202,8 +202,15 @@ pub(crate) fn symbol<'db>( db: &'db dyn Db, scope: ScopeId<'db>, name: &str, + considered_definitions: ConsideredDefinitions, ) -> PlaceAndQualifiers<'db> { - symbol_impl(db, scope, name, RequiresExplicitReExport::No) + symbol_impl( + db, + scope, + name, + RequiresExplicitReExport::No, + considered_definitions, + ) } /// Infer the public type of a place (its type as seen from outside its scope) in the given @@ -212,8 +219,15 @@ pub(crate) fn place<'db>( db: &'db dyn Db, scope: ScopeId<'db>, expr: &PlaceExpr, + considered_definitions: ConsideredDefinitions, ) -> PlaceAndQualifiers<'db> { - place_impl(db, scope, expr, RequiresExplicitReExport::No) + place_impl( + db, + scope, + expr, + RequiresExplicitReExport::No, + considered_definitions, + ) } /// Infer the public type of a class symbol (its type as seen from outside its scope) in the given @@ -226,7 +240,13 @@ pub(crate) fn class_symbol<'db>( place_table(db, scope) .place_id_by_name(name) .map(|symbol| { - let symbol_and_quals = place_by_id(db, scope, symbol, RequiresExplicitReExport::No); + let symbol_and_quals = place_by_id( + db, + scope, + symbol, + RequiresExplicitReExport::No, + ConsideredDefinitions::EndOfScope, + ); if symbol_and_quals.is_class_var() { // For declared class vars we do not need to check if they have bindings, @@ -241,7 +261,7 @@ pub(crate) fn class_symbol<'db>( { // Otherwise, we need to check if the symbol has bindings let use_def = use_def_map(db, scope); - let bindings = use_def.public_bindings(symbol); + let bindings = use_def.end_of_scope_bindings(symbol); let inferred = place_from_bindings_impl(db, bindings, RequiresExplicitReExport::No); // TODO: we should not need to calculate inferred type second time. This is a temporary @@ -277,6 +297,7 @@ pub(crate) fn explicit_global_symbol<'db>( global_scope(db, file), name, RequiresExplicitReExport::No, + ConsideredDefinitions::AllReachable, ) } @@ -330,18 +351,22 @@ pub(crate) fn imported_symbol<'db>( // ignore `__getattr__`. Typeshed has a fake `__getattr__` on `types.ModuleType` to help out with // dynamic imports; we shouldn't use it for `ModuleLiteral` types where we know exactly which // module we're dealing with. - symbol_impl(db, global_scope(db, file), name, requires_explicit_reexport).or_fall_back_to( + symbol_impl( db, - || { - if name == "__getattr__" { - Place::Unbound.into() - } else if name == "__builtins__" { - Place::bound(Type::any()).into() - } else { - KnownClass::ModuleType.to_instance(db).member(db, name) - } - }, + global_scope(db, file), + name, + requires_explicit_reexport, + ConsideredDefinitions::EndOfScope, ) + .or_fall_back_to(db, || { + if name == "__getattr__" { + Place::Unbound.into() + } else if name == "__builtins__" { + Place::bound(Type::any()).into() + } else { + KnownClass::ModuleType.to_instance(db).member(db, name) + } + }) } /// Lookup the type of `symbol` in the builtins namespace. @@ -361,6 +386,7 @@ pub(crate) fn builtins_symbol<'db>(db: &'db dyn Db, symbol: &str) -> PlaceAndQua global_scope(db, file), symbol, RequiresExplicitReExport::Yes, + ConsideredDefinitions::EndOfScope, ) .or_fall_back_to(db, || { // We're looking up in the builtins namespace and not the module, so we should @@ -450,9 +476,12 @@ pub(crate) fn place_from_declarations<'db>( place_from_declarations_impl(db, declarations, RequiresExplicitReExport::No) } +pub(crate) type DeclaredTypeAndConflictingTypes<'db> = + (TypeAndQualifiers<'db>, Box>>); + /// The result of looking up a declared type from declarations; see [`place_from_declarations`]. pub(crate) type PlaceFromDeclarationsResult<'db> = - Result, (TypeAndQualifiers<'db>, Box<[Type<'db>]>)>; + Result, DeclaredTypeAndConflictingTypes<'db>>; /// A type with declaredness information, and a set of type qualifiers. /// @@ -581,6 +610,7 @@ fn place_cycle_recover<'db>( _scope: ScopeId<'db>, _place_id: ScopedPlaceId, _requires_explicit_reexport: RequiresExplicitReExport, + _considered_definitions: ConsideredDefinitions, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } @@ -590,6 +620,7 @@ fn place_cycle_initial<'db>( _scope: ScopeId<'db>, _place_id: ScopedPlaceId, _requires_explicit_reexport: RequiresExplicitReExport, + _considered_definitions: ConsideredDefinitions, ) -> PlaceAndQualifiers<'db> { Place::bound(Type::Never).into() } @@ -600,15 +631,25 @@ fn place_by_id<'db>( scope: ScopeId<'db>, place_id: ScopedPlaceId, requires_explicit_reexport: RequiresExplicitReExport, + considered_definitions: ConsideredDefinitions, ) -> PlaceAndQualifiers<'db> { let use_def = use_def_map(db, scope); // If the place is declared, the public type is based on declarations; otherwise, it's based // on inference from bindings. - let declarations = use_def.public_declarations(place_id); + let declarations = match considered_definitions { + ConsideredDefinitions::EndOfScope => use_def.end_of_scope_declarations(place_id), + ConsideredDefinitions::AllReachable => use_def.all_reachable_declarations(place_id), + }; + let declared = place_from_declarations_impl(db, declarations, requires_explicit_reexport); + let all_considered_bindings = || match considered_definitions { + ConsideredDefinitions::EndOfScope => use_def.end_of_scope_bindings(place_id), + ConsideredDefinitions::AllReachable => use_def.all_reachable_bindings(place_id), + }; + match declared { // Place is declared, trust the declared type Ok( @@ -622,7 +663,8 @@ fn place_by_id<'db>( place: Place::Type(declared_ty, Boundness::PossiblyUnbound), qualifiers, }) => { - let bindings = use_def.public_bindings(place_id); + let bindings = all_considered_bindings(); + let boundness_analysis = bindings.boundness_analysis; let inferred = place_from_bindings_impl(db, bindings, requires_explicit_reexport); let place = match inferred { @@ -636,7 +678,11 @@ fn place_by_id<'db>( // Place is possibly undeclared and (possibly) bound Place::Type(inferred_ty, boundness) => Place::Type( UnionType::from_elements(db, [inferred_ty, declared_ty]), - boundness, + if boundness_analysis == BoundnessAnalysis::AssumeBound { + Boundness::Bound + } else { + boundness + }, ), }; @@ -647,8 +693,15 @@ fn place_by_id<'db>( place: Place::Unbound, qualifiers: _, }) => { - let bindings = use_def.public_bindings(place_id); - let inferred = place_from_bindings_impl(db, bindings, requires_explicit_reexport); + let bindings = all_considered_bindings(); + let boundness_analysis = bindings.boundness_analysis; + let mut inferred = place_from_bindings_impl(db, bindings, requires_explicit_reexport); + + if boundness_analysis == BoundnessAnalysis::AssumeBound { + if let Place::Type(ty, Boundness::PossiblyUnbound) = inferred { + inferred = Place::Type(ty, Boundness::Bound); + } + } // `__slots__` is a symbol with special behavior in Python's runtime. It can be // modified externally, but those changes do not take effect. We therefore issue @@ -707,6 +760,7 @@ fn symbol_impl<'db>( scope: ScopeId<'db>, name: &str, requires_explicit_reexport: RequiresExplicitReExport, + considered_definitions: ConsideredDefinitions, ) -> PlaceAndQualifiers<'db> { let _span = tracing::trace_span!("symbol", ?name).entered(); @@ -726,7 +780,15 @@ fn symbol_impl<'db>( place_table(db, scope) .place_id_by_name(name) - .map(|symbol| place_by_id(db, scope, symbol, requires_explicit_reexport)) + .map(|symbol| { + place_by_id( + db, + scope, + symbol, + requires_explicit_reexport, + considered_definitions, + ) + }) .unwrap_or_default() } @@ -736,12 +798,21 @@ fn place_impl<'db>( scope: ScopeId<'db>, expr: &PlaceExpr, requires_explicit_reexport: RequiresExplicitReExport, + considered_definitions: ConsideredDefinitions, ) -> PlaceAndQualifiers<'db> { let _span = tracing::trace_span!("place", ?expr).entered(); place_table(db, scope) .place_id_by_expr(expr) - .map(|place| place_by_id(db, scope, place, requires_explicit_reexport)) + .map(|place| { + place_by_id( + db, + scope, + place, + requires_explicit_reexport, + considered_definitions, + ) + }) .unwrap_or_default() } @@ -757,6 +828,7 @@ fn place_from_bindings_impl<'db>( ) -> Place<'db> { let predicates = bindings_with_constraints.predicates; let reachability_constraints = bindings_with_constraints.reachability_constraints; + let boundness_analysis = bindings_with_constraints.boundness_analysis; let mut bindings_with_constraints = bindings_with_constraints.peekable(); let is_non_exported = |binding: Definition<'db>| { @@ -776,7 +848,7 @@ fn place_from_bindings_impl<'db>( // Evaluate this lazily because we don't always need it (for example, if there are no visible // bindings at all, we don't need it), and it can cause us to evaluate reachability constraint // expressions, which is extra work and can lead to cycles. - let unbound_reachability = || { + let unbound_visibility = || { unbound_reachability_constraint.map(|reachability_constraint| { reachability_constraints.evaluate(db, predicates, reachability_constraint) }) @@ -856,7 +928,7 @@ fn place_from_bindings_impl<'db>( // return `Never` in this case, because we will union the types of all bindings, and // `Never` will be eliminated automatically. - if unbound_reachability().is_none_or(Truthiness::is_always_false) { + if unbound_visibility().is_none_or(Truthiness::is_always_false) { return Some(Type::Never); } return None; @@ -868,21 +940,33 @@ fn place_from_bindings_impl<'db>( ); if let Some(first) = types.next() { - let boundness = match unbound_reachability() { - Some(Truthiness::AlwaysTrue) => { - unreachable!( - "If we have at least one binding, the implicit `unbound` binding should not be definitely visible" - ) - } - Some(Truthiness::AlwaysFalse) | None => Boundness::Bound, - Some(Truthiness::Ambiguous) => Boundness::PossiblyUnbound, - }; - let ty = if let Some(second) = types.next() { - UnionType::from_elements(db, [first, second].into_iter().chain(types)) + let mut builder = PublicTypeBuilder::new(db); + builder.add(first); + builder.add(second); + + for ty in types { + builder.add(ty); + } + + builder.build() } else { first }; + + let boundness = match boundness_analysis { + BoundnessAnalysis::AssumeBound => Boundness::Bound, + BoundnessAnalysis::BasedOnUnboundVisibility => match unbound_visibility() { + Some(Truthiness::AlwaysTrue) => { + unreachable!( + "If we have at least one binding, the implicit `unbound` binding should not be definitely visible" + ) + } + Some(Truthiness::AlwaysFalse) | None => Boundness::Bound, + Some(Truthiness::Ambiguous) => Boundness::PossiblyUnbound, + }, + }; + match deleted_reachability { Truthiness::AlwaysFalse => Place::Type(ty, boundness), Truthiness::AlwaysTrue => Place::Unbound, @@ -893,6 +977,118 @@ fn place_from_bindings_impl<'db>( } } +/// Accumulates types from multiple bindings or declarations, and eventually builds a +/// union type from them. +/// +/// `@overload`ed function literal types are discarded if they are immediately followed +/// by their implementation. This is to ensure that we do not merge all of them into the +/// union type. The last one will include the other overloads already. +struct PublicTypeBuilder<'db> { + db: &'db dyn Db, + queue: Option>, + builder: UnionBuilder<'db>, +} + +impl<'db> PublicTypeBuilder<'db> { + fn new(db: &'db dyn Db) -> Self { + PublicTypeBuilder { + db, + queue: None, + builder: UnionBuilder::new(db), + } + } + + fn add_to_union(&mut self, element: Type<'db>) { + self.builder.add_in_place(element); + } + + fn drain_queue(&mut self) { + if let Some(queued_element) = self.queue.take() { + self.add_to_union(queued_element); + } + } + + fn add(&mut self, element: Type<'db>) -> bool { + match element { + Type::FunctionLiteral(function) => { + if function + .literal(self.db) + .last_definition(self.db) + .is_overload(self.db) + { + self.queue = Some(element); + false + } else { + self.queue = None; + self.add_to_union(element); + true + } + } + _ => { + self.drain_queue(); + self.add_to_union(element); + true + } + } + } + + fn build(mut self) -> Type<'db> { + self.drain_queue(); + self.builder.build() + } +} + +/// Accumulates multiple (potentially conflicting) declared types and type qualifiers, +/// and eventually builds a union from them. +struct DeclaredTypeBuilder<'db> { + inner: PublicTypeBuilder<'db>, + qualifiers: TypeQualifiers, + first_type: Option>, + conflicting_types: FxOrderSet>, +} + +impl<'db> DeclaredTypeBuilder<'db> { + fn new(db: &'db dyn Db) -> Self { + DeclaredTypeBuilder { + inner: PublicTypeBuilder::new(db), + qualifiers: TypeQualifiers::empty(), + first_type: None, + conflicting_types: FxOrderSet::default(), + } + } + + fn add(&mut self, element: TypeAndQualifiers<'db>) { + let element_ty = element.inner_type(); + + if self.inner.add(element_ty) { + if let Some(first_ty) = self.first_type { + if !first_ty.is_equivalent_to(self.inner.db, element_ty) { + self.conflicting_types.insert(element_ty); + } + } else { + self.first_type = Some(element_ty); + } + } + + self.qualifiers = self.qualifiers.union(element.qualifiers()); + } + + fn build(mut self) -> DeclaredTypeAndConflictingTypes<'db> { + if !self.conflicting_types.is_empty() { + self.conflicting_types.insert_before( + 0, + self.first_type + .expect("there must be a first type if there are conflicting types"), + ); + } + + ( + TypeAndQualifiers::new(self.inner.build(), self.qualifiers), + self.conflicting_types.into_boxed_slice(), + ) + } +} + /// Implementation of [`place_from_declarations`]. /// /// ## Implementation Note @@ -905,6 +1101,7 @@ fn place_from_declarations_impl<'db>( ) -> PlaceFromDeclarationsResult<'db> { let predicates = declarations.predicates; let reachability_constraints = declarations.reachability_constraints; + let boundness_analysis = declarations.boundness_analysis; let mut declarations = declarations.peekable(); let is_non_exported = |declaration: Definition<'db>| { @@ -921,7 +1118,9 @@ fn place_from_declarations_impl<'db>( _ => Truthiness::AlwaysFalse, }; - let mut types = declarations.filter_map( + let mut all_declarations_definitely_reachable = true; + + let types = declarations.filter_map( |DeclarationWithConstraint { declaration, reachability_constraint, @@ -940,32 +1139,40 @@ fn place_from_declarations_impl<'db>( if static_reachability.is_always_false() { None } else { + all_declarations_definitely_reachable = + all_declarations_definitely_reachable && static_reachability.is_always_true(); + Some(declaration_type(db, declaration)) } }, ); - if let Some(first) = types.next() { - let mut conflicting: Vec> = vec![]; - let declared = if let Some(second) = types.next() { - let ty_first = first.inner_type(); - let mut qualifiers = first.qualifiers(); + let mut types = types.peekable(); - let mut builder = UnionBuilder::new(db).add(ty_first); - for other in std::iter::once(second).chain(types) { - let other_ty = other.inner_type(); - if !ty_first.is_equivalent_to(db, other_ty) { - conflicting.push(other_ty); + if types.peek().is_some() { + let mut builder = DeclaredTypeBuilder::new(db); + for element in types { + builder.add(element); + } + let (declared, conflicting) = builder.build(); + + if !conflicting.is_empty() { + return Err((declared, conflicting)); + } + + let boundness = match boundness_analysis { + BoundnessAnalysis::AssumeBound => { + if all_declarations_definitely_reachable { + Boundness::Bound + } else { + // For declarations, it is important to consider the possibility that they might only + // be bound in one control flow path, while the other path contains a binding. In order + // to even consider the bindings as well in `place_by_id`, we return `PossiblyUnbound` + // here. + Boundness::PossiblyUnbound } - builder = builder.add(other_ty); - qualifiers = qualifiers.union(other.qualifiers()); } - TypeAndQualifiers::new(builder.build(), qualifiers) - } else { - first - }; - if conflicting.is_empty() { - let boundness = match undeclared_reachability { + BoundnessAnalysis::BasedOnUnboundVisibility => match undeclared_reachability { Truthiness::AlwaysTrue => { unreachable!( "If we have at least one declaration, the implicit `unbound` binding should not be definitely visible" @@ -973,20 +1180,10 @@ fn place_from_declarations_impl<'db>( } Truthiness::AlwaysFalse => Boundness::Bound, Truthiness::Ambiguous => Boundness::PossiblyUnbound, - }; + }, + }; - Ok( - Place::Type(declared.inner_type(), boundness) - .with_qualifiers(declared.qualifiers()), - ) - } else { - Err(( - declared, - std::iter::once(first.inner_type()) - .chain(conflicting) - .collect(), - )) - } + Ok(Place::Type(declared.inner_type(), boundness).with_qualifiers(declared.qualifiers())) } else { Ok(Place::Unbound.into()) } @@ -1045,7 +1242,7 @@ mod implicit_globals { }; place_from_declarations( db, - use_def_map(db, module_type_scope).public_declarations(place_id), + use_def_map(db, module_type_scope).end_of_scope_declarations(place_id), ) } @@ -1165,6 +1362,48 @@ impl RequiresExplicitReExport { } } +/// Specifies which definitions should be considered when looking up a place. +/// +/// In the example below, the `EndOfScope` variant would consider the `x = 2` and `x = 3` definitions, +/// while the `AllReachable` variant would also consider the `x = 1` definition. +/// ```py +/// def _(): +/// x = 1 +/// +/// x = 2 +/// +/// if flag(): +/// x = 3 +/// ``` +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] +pub(crate) enum ConsideredDefinitions { + /// Consider only the definitions that are "live" at the end of the scope, i.e. those + /// that have not been shadowed or deleted. + EndOfScope, + /// Consider all definitions that are reachable from the start of the scope. + AllReachable, +} + +/// Specifies how the boundness of a place should be determined. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] +pub(crate) enum BoundnessAnalysis { + /// The place is always considered bound. + AssumeBound, + /// The boundness of the place is determined based on the visibility of the implicit + /// `unbound` binding. In the example below, when analyzing the visibility of the + /// `x = ` binding from the position of the end of the scope, it would be + /// `Truthiness::Ambiguous`, because it could either be visible or not, depending on the + /// `flag()` return value. This would result in a `Boundness::PossiblyUnbound` for `x`. + /// + /// ```py + /// x = + /// + /// if flag(): + /// x = 1 + /// ``` + BasedOnUnboundVisibility, +} + /// Computes a possibly-widened type `Unknown | T_inferred` from the inferred type `T_inferred` /// of a symbol, unless the type is a known-instance type (e.g. `typing.Any`) or the symbol is /// considered non-modifiable (e.g. when the symbol is `@Final`). We need this for public uses diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index d4f848e69d..009674ba26 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -116,7 +116,7 @@ pub(crate) fn attribute_assignments<'db, 's>( let place_table = index.place_table(function_scope_id); let place = place_table.place_id_by_instance_attribute_name(name)?; let use_def = &index.use_def_maps[function_scope_id]; - Some((use_def.public_bindings(place), function_scope_id)) + Some((use_def.end_of_scope_bindings(place), function_scope_id)) }) } @@ -574,7 +574,7 @@ mod tests { impl UseDefMap<'_> { fn first_public_binding(&self, symbol: ScopedPlaceId) -> Option> { - self.public_bindings(symbol) + self.end_of_scope_bindings(symbol) .find_map(|constrained_binding| constrained_binding.binding.definition()) } diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index 8ac7f91811..e216d51e60 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -237,6 +237,7 @@ use self::place_state::{ LiveDeclarationsIterator, PlaceState, ScopedDefinitionId, }; use crate::node_key::NodeKey; +use crate::place::BoundnessAnalysis; use crate::semantic_index::ast_ids::ScopedUseId; use crate::semantic_index::definition::{Definition, DefinitionState}; use crate::semantic_index::narrowing_constraints::{ @@ -251,6 +252,7 @@ use crate::semantic_index::predicate::{ use crate::semantic_index::reachability_constraints::{ ReachabilityConstraints, ReachabilityConstraintsBuilder, ScopedReachabilityConstraintId, }; +use crate::semantic_index::use_def::place_state::PreviousDefinitions; use crate::semantic_index::{EagerSnapshotResult, SemanticIndex}; use crate::types::{IntersectionBuilder, Truthiness, Type, infer_narrowing_constraint}; @@ -296,7 +298,10 @@ pub(crate) struct UseDefMap<'db> { bindings_by_declaration: FxHashMap, Bindings>, /// [`PlaceState`] visible at end of scope for each place. - public_places: IndexVec, + end_of_scope_places: IndexVec, + + /// All potentially reachable bindings and declarations, for each place. + reachable_definitions: IndexVec, /// Snapshot of bindings in this scope that can be used to resolve a reference in a nested /// eager scope. @@ -332,7 +337,10 @@ impl<'db> UseDefMap<'db> { &self, use_id: ScopedUseId, ) -> BindingWithConstraintsIterator<'_, 'db> { - self.bindings_iterator(&self.bindings_by_use[use_id]) + self.bindings_iterator( + &self.bindings_by_use[use_id], + BoundnessAnalysis::BasedOnUnboundVisibility, + ) } pub(crate) fn applicable_constraints( @@ -394,11 +402,24 @@ impl<'db> UseDefMap<'db> { .may_be_true() } - pub(crate) fn public_bindings( + pub(crate) fn end_of_scope_bindings( &self, place: ScopedPlaceId, ) -> BindingWithConstraintsIterator<'_, 'db> { - self.bindings_iterator(self.public_places[place].bindings()) + self.bindings_iterator( + self.end_of_scope_places[place].bindings(), + BoundnessAnalysis::BasedOnUnboundVisibility, + ) + } + + pub(crate) fn all_reachable_bindings( + &self, + place: ScopedPlaceId, + ) -> BindingWithConstraintsIterator<'_, 'db> { + self.bindings_iterator( + &self.reachable_definitions[place].bindings, + BoundnessAnalysis::AssumeBound, + ) } pub(crate) fn eager_snapshot( @@ -409,9 +430,9 @@ impl<'db> UseDefMap<'db> { Some(EagerSnapshot::Constraint(constraint)) => { EagerSnapshotResult::FoundConstraint(*constraint) } - Some(EagerSnapshot::Bindings(bindings)) => { - EagerSnapshotResult::FoundBindings(self.bindings_iterator(bindings)) - } + Some(EagerSnapshot::Bindings(bindings)) => EagerSnapshotResult::FoundBindings( + self.bindings_iterator(bindings, BoundnessAnalysis::BasedOnUnboundVisibility), + ), None => EagerSnapshotResult::NotFound, } } @@ -420,39 +441,53 @@ impl<'db> UseDefMap<'db> { &self, declaration: Definition<'db>, ) -> BindingWithConstraintsIterator<'_, 'db> { - self.bindings_iterator(&self.bindings_by_declaration[&declaration]) + self.bindings_iterator( + &self.bindings_by_declaration[&declaration], + BoundnessAnalysis::BasedOnUnboundVisibility, + ) } pub(crate) fn declarations_at_binding( &self, binding: Definition<'db>, ) -> DeclarationsIterator<'_, 'db> { - self.declarations_iterator(&self.declarations_by_binding[&binding]) + self.declarations_iterator( + &self.declarations_by_binding[&binding], + BoundnessAnalysis::BasedOnUnboundVisibility, + ) } - pub(crate) fn public_declarations<'map>( + pub(crate) fn end_of_scope_declarations<'map>( &'map self, place: ScopedPlaceId, ) -> DeclarationsIterator<'map, 'db> { - let declarations = self.public_places[place].declarations(); - self.declarations_iterator(declarations) + let declarations = self.end_of_scope_places[place].declarations(); + self.declarations_iterator(declarations, BoundnessAnalysis::BasedOnUnboundVisibility) } - pub(crate) fn all_public_declarations<'map>( + pub(crate) fn all_reachable_declarations( + &self, + place: ScopedPlaceId, + ) -> DeclarationsIterator<'_, 'db> { + let declarations = &self.reachable_definitions[place].declarations; + self.declarations_iterator(declarations, BoundnessAnalysis::AssumeBound) + } + + pub(crate) fn all_end_of_scope_declarations<'map>( &'map self, ) -> impl Iterator)> + 'map { - (0..self.public_places.len()) + (0..self.end_of_scope_places.len()) .map(ScopedPlaceId::from_usize) - .map(|place_id| (place_id, self.public_declarations(place_id))) + .map(|place_id| (place_id, self.end_of_scope_declarations(place_id))) } - pub(crate) fn all_public_bindings<'map>( + pub(crate) fn all_end_of_scope_bindings<'map>( &'map self, ) -> impl Iterator)> + 'map { - (0..self.public_places.len()) + (0..self.end_of_scope_places.len()) .map(ScopedPlaceId::from_usize) - .map(|place_id| (place_id, self.public_bindings(place_id))) + .map(|place_id| (place_id, self.end_of_scope_bindings(place_id))) } /// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`. @@ -478,12 +513,14 @@ impl<'db> UseDefMap<'db> { fn bindings_iterator<'map>( &'map self, bindings: &'map Bindings, + boundness_analysis: BoundnessAnalysis, ) -> BindingWithConstraintsIterator<'map, 'db> { BindingWithConstraintsIterator { all_definitions: &self.all_definitions, predicates: &self.predicates, narrowing_constraints: &self.narrowing_constraints, reachability_constraints: &self.reachability_constraints, + boundness_analysis, inner: bindings.iter(), } } @@ -491,11 +528,13 @@ impl<'db> UseDefMap<'db> { fn declarations_iterator<'map>( &'map self, declarations: &'map Declarations, + boundness_analysis: BoundnessAnalysis, ) -> DeclarationsIterator<'map, 'db> { DeclarationsIterator { all_definitions: &self.all_definitions, predicates: &self.predicates, reachability_constraints: &self.reachability_constraints, + boundness_analysis, inner: declarations.iter(), } } @@ -531,6 +570,7 @@ pub(crate) struct BindingWithConstraintsIterator<'map, 'db> { pub(crate) predicates: &'map Predicates<'db>, pub(crate) narrowing_constraints: &'map NarrowingConstraints, pub(crate) reachability_constraints: &'map ReachabilityConstraints, + pub(crate) boundness_analysis: BoundnessAnalysis, inner: LiveBindingsIterator<'map>, } @@ -611,6 +651,7 @@ pub(crate) struct DeclarationsIterator<'map, 'db> { all_definitions: &'map IndexVec>, pub(crate) predicates: &'map Predicates<'db>, pub(crate) reachability_constraints: &'map ReachabilityConstraints, + pub(crate) boundness_analysis: BoundnessAnalysis, inner: LiveDeclarationsIterator<'map>, } @@ -639,6 +680,12 @@ impl<'db> Iterator for DeclarationsIterator<'_, 'db> { impl std::iter::FusedIterator for DeclarationsIterator<'_, '_> {} +#[derive(Debug, PartialEq, Eq, salsa::Update)] +struct ReachableDefinitions { + bindings: Bindings, + declarations: Declarations, +} + /// A snapshot of the definitions and constraints state at a particular point in control flow. #[derive(Clone, Debug)] pub(super) struct FlowSnapshot { @@ -648,7 +695,7 @@ pub(super) struct FlowSnapshot { #[derive(Debug)] pub(super) struct UseDefMapBuilder<'db> { - /// Append-only array of [`Definition`]. + /// Append-only array of [`DefinitionState`]. all_definitions: IndexVec>, /// Builder of predicates. @@ -679,6 +726,9 @@ pub(super) struct UseDefMapBuilder<'db> { /// Currently live bindings and declarations for each place. place_states: IndexVec, + /// All potentially reachable bindings and declarations, for each place. + reachable_definitions: IndexVec, + /// Snapshots of place states in this scope that can be used to resolve a reference in a /// nested eager scope. eager_snapshots: EagerSnapshots, @@ -700,6 +750,7 @@ impl<'db> UseDefMapBuilder<'db> { declarations_by_binding: FxHashMap::default(), bindings_by_declaration: FxHashMap::default(), place_states: IndexVec::new(), + reachable_definitions: IndexVec::new(), eager_snapshots: EagerSnapshots::default(), is_class_scope, } @@ -720,6 +771,11 @@ impl<'db> UseDefMapBuilder<'db> { .place_states .push(PlaceState::undefined(self.reachability)); debug_assert_eq!(place, new_place); + let new_place = self.reachable_definitions.push(ReachableDefinitions { + bindings: Bindings::unbound(self.reachability), + declarations: Declarations::undeclared(self.reachability), + }); + debug_assert_eq!(place, new_place); } pub(super) fn record_binding( @@ -738,6 +794,14 @@ impl<'db> UseDefMapBuilder<'db> { self.is_class_scope, is_place_name, ); + + self.reachable_definitions[place].bindings.record_binding( + def_id, + self.reachability, + self.is_class_scope, + is_place_name, + PreviousDefinitions::AreKept, + ); } pub(super) fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId { @@ -845,6 +909,10 @@ impl<'db> UseDefMapBuilder<'db> { self.bindings_by_declaration .insert(declaration, place_state.bindings().clone()); place_state.record_declaration(def_id, self.reachability); + + self.reachable_definitions[place] + .declarations + .record_declaration(def_id, self.reachability, PreviousDefinitions::AreKept); } pub(super) fn record_declaration_and_binding( @@ -866,6 +934,17 @@ impl<'db> UseDefMapBuilder<'db> { self.is_class_scope, is_place_name, ); + + self.reachable_definitions[place] + .declarations + .record_declaration(def_id, self.reachability, PreviousDefinitions::AreKept); + self.reachable_definitions[place].bindings.record_binding( + def_id, + self.reachability, + self.is_class_scope, + is_place_name, + PreviousDefinitions::AreKept, + ); } pub(super) fn delete_binding(&mut self, place: ScopedPlaceId, is_place_name: bool) { @@ -1000,6 +1079,7 @@ impl<'db> UseDefMapBuilder<'db> { pub(super) fn finish(mut self) -> UseDefMap<'db> { self.all_definitions.shrink_to_fit(); self.place_states.shrink_to_fit(); + self.reachable_definitions.shrink_to_fit(); self.bindings_by_use.shrink_to_fit(); self.node_reachability.shrink_to_fit(); self.declarations_by_binding.shrink_to_fit(); @@ -1013,7 +1093,8 @@ impl<'db> UseDefMapBuilder<'db> { reachability_constraints: self.reachability_constraints.build(), bindings_by_use: self.bindings_by_use, node_reachability: self.node_reachability, - public_places: self.place_states, + end_of_scope_places: self.place_states, + reachable_definitions: self.reachable_definitions, declarations_by_binding: self.declarations_by_binding, bindings_by_declaration: self.bindings_by_declaration, eager_snapshots: self.eager_snapshots, diff --git a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs index 44857e66b9..dc10dc7ef2 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs @@ -92,8 +92,20 @@ pub(super) struct LiveDeclaration { pub(super) type LiveDeclarationsIterator<'a> = std::slice::Iter<'a, LiveDeclaration>; +#[derive(Clone, Copy, Debug)] +pub(super) enum PreviousDefinitions { + AreShadowed, + AreKept, +} + +impl PreviousDefinitions { + pub(super) fn are_shadowed(self) -> bool { + matches!(self, PreviousDefinitions::AreShadowed) + } +} + impl Declarations { - fn undeclared(reachability_constraint: ScopedReachabilityConstraintId) -> Self { + pub(super) fn undeclared(reachability_constraint: ScopedReachabilityConstraintId) -> Self { let initial_declaration = LiveDeclaration { declaration: ScopedDefinitionId::UNBOUND, reachability_constraint, @@ -104,13 +116,16 @@ impl Declarations { } /// Record a newly-encountered declaration for this place. - fn record_declaration( + pub(super) fn record_declaration( &mut self, declaration: ScopedDefinitionId, reachability_constraint: ScopedReachabilityConstraintId, + previous_definitions: PreviousDefinitions, ) { - // The new declaration replaces all previous live declaration in this path. - self.live_declarations.clear(); + if previous_definitions.are_shadowed() { + // The new declaration replaces all previous live declaration in this path. + self.live_declarations.clear(); + } self.live_declarations.push(LiveDeclaration { declaration, reachability_constraint, @@ -205,7 +220,7 @@ pub(super) struct LiveBinding { pub(super) type LiveBindingsIterator<'a> = std::slice::Iter<'a, LiveBinding>; impl Bindings { - fn unbound(reachability_constraint: ScopedReachabilityConstraintId) -> Self { + pub(super) fn unbound(reachability_constraint: ScopedReachabilityConstraintId) -> Self { let initial_binding = LiveBinding { binding: ScopedDefinitionId::UNBOUND, narrowing_constraint: ScopedNarrowingConstraint::empty(), @@ -224,6 +239,7 @@ impl Bindings { reachability_constraint: ScopedReachabilityConstraintId, is_class_scope: bool, is_place_name: bool, + previous_definitions: PreviousDefinitions, ) { // If we are in a class scope, and the unbound name binding was previously visible, but we will // now replace it, record the narrowing constraints on it: @@ -232,7 +248,9 @@ impl Bindings { } // The new binding replaces all previous live bindings in this path, and has no // constraints. - self.live_bindings.clear(); + if previous_definitions.are_shadowed() { + self.live_bindings.clear(); + } self.live_bindings.push(LiveBinding { binding, narrowing_constraint: ScopedNarrowingConstraint::empty(), @@ -349,6 +367,7 @@ impl PlaceState { reachability_constraint, is_class_scope, is_place_name, + PreviousDefinitions::AreShadowed, ); } @@ -380,8 +399,11 @@ impl PlaceState { declaration_id: ScopedDefinitionId, reachability_constraint: ScopedReachabilityConstraintId, ) { - self.declarations - .record_declaration(declaration_id, reachability_constraint); + self.declarations.record_declaration( + declaration_id, + reachability_constraint, + PreviousDefinitions::AreShadowed, + ); } /// Merge another [`PlaceState`] into this one. diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index af1a258b6e..94c2852790 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -23,7 +23,6 @@ use type_ordering::union_or_intersection_elements_ordering; pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder}; pub use self::diagnostic::TypeCheckDiagnostics; pub(crate) use self::diagnostic::register_lints; -pub(crate) use self::display::TypeArrayDisplay; pub(crate) use self::infer::{ infer_deferred_types, infer_definition_types, infer_expression_type, infer_expression_types, infer_scope_types, diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 4eeb050046..e479bd020b 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1603,7 +1603,7 @@ impl<'db> ClassLiteral<'db> { let table = place_table(db, class_body_scope); let use_def = use_def_map(db, class_body_scope); - for (place_id, declarations) in use_def.all_public_declarations() { + for (place_id, declarations) in use_def.all_end_of_scope_declarations() { // Here, we exclude all declarations that are not annotated assignments. We need this because // things like function definitions and nested classes would otherwise be considered dataclass // fields. The check is too broad in the sense that it also excludes (weird) constructs where @@ -1633,7 +1633,7 @@ impl<'db> ClassLiteral<'db> { } if let Some(attr_ty) = attr.place.ignore_possibly_unbound() { - let bindings = use_def.public_bindings(place_id); + let bindings = use_def.end_of_scope_bindings(place_id); let default_ty = place_from_bindings(db, bindings).ignore_possibly_unbound(); attributes.insert(place_expr.expect_name().clone(), (attr_ty, default_ty)); @@ -1750,7 +1750,7 @@ impl<'db> ClassLiteral<'db> { let method = index.expect_single_definition(method_def); let method_place = class_table.place_id_by_name(&method_def.name).unwrap(); class_map - .public_bindings(method_place) + .end_of_scope_bindings(method_place) .find_map(|bind| { (bind.binding.is_defined_and(|def| def == method)) .then(|| class_map.is_binding_reachable(db, &bind)) @@ -1994,7 +1994,7 @@ impl<'db> ClassLiteral<'db> { if let Some(place_id) = table.place_id_by_name(name) { let use_def = use_def_map(db, body_scope); - let declarations = use_def.public_declarations(place_id); + let declarations = use_def.end_of_scope_declarations(place_id); let declared_and_qualifiers = place_from_declarations(db, declarations); match declared_and_qualifiers { Ok(PlaceAndQualifiers { @@ -2009,7 +2009,7 @@ impl<'db> ClassLiteral<'db> { // The attribute is declared in the class body. - let bindings = use_def.public_bindings(place_id); + let bindings = use_def.end_of_scope_bindings(place_id); let inferred = place_from_bindings(db, bindings); let has_binding = !inferred.is_unbound(); diff --git a/crates/ty_python_semantic/src/types/ide_support.rs b/crates/ty_python_semantic/src/types/ide_support.rs index c668d1f0eb..7cf05fa57f 100644 --- a/crates/ty_python_semantic/src/types/ide_support.rs +++ b/crates/ty_python_semantic/src/types/ide_support.rs @@ -16,7 +16,7 @@ pub(crate) fn all_declarations_and_bindings<'db>( let table = place_table(db, scope_id); use_def_map - .all_public_declarations() + .all_end_of_scope_declarations() .filter_map(move |(symbol_id, declarations)| { place_from_declarations(db, declarations) .ok() @@ -29,7 +29,7 @@ pub(crate) fn all_declarations_and_bindings<'db>( }) .chain( use_def_map - .all_public_bindings() + .all_end_of_scope_bindings() .filter_map(move |(symbol_id, bindings)| { place_from_bindings(db, bindings) .ignore_possibly_unbound() @@ -140,7 +140,7 @@ impl AllMembers { let use_def_map = use_def_map(db, module_scope); let place_table = place_table(db, module_scope); - for (symbol_id, _) in use_def_map.all_public_declarations() { + for (symbol_id, _) in use_def_map.all_end_of_scope_declarations() { let Some(symbol_name) = place_table.place_expr(symbol_id).as_name() else { continue; }; diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 022804b1b7..e66d7aa119 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -49,10 +49,10 @@ use crate::module_name::{ModuleName, ModuleNameResolutionError}; use crate::module_resolver::resolve_module; use crate::node_key::NodeKey; use crate::place::{ - Boundness, LookupError, Place, PlaceAndQualifiers, builtins_module_scope, builtins_symbol, - explicit_global_symbol, global_symbol, module_type_implicit_global_declaration, - module_type_implicit_global_symbol, place, place_from_bindings, place_from_declarations, - typing_extensions_symbol, + Boundness, ConsideredDefinitions, LookupError, Place, PlaceAndQualifiers, + builtins_module_scope, builtins_symbol, explicit_global_symbol, global_symbol, + module_type_implicit_global_declaration, module_type_implicit_global_symbol, place, + place_from_bindings, place_from_declarations, typing_extensions_symbol, }; use crate::semantic_index::ast_ids::{ HasScopedExpressionId, HasScopedUseId, ScopedExpressionId, ScopedUseId, @@ -101,9 +101,9 @@ use crate::types::{ IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, LintDiagnosticGuard, MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, StringLiteralType, SubclassOfType, Truthiness, Type, - TypeAliasType, TypeAndQualifiers, TypeArrayDisplay, TypeIsType, TypeQualifiers, - TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypeVarVariance, UnionBuilder, - UnionType, binding_type, todo_type, + TypeAliasType, TypeAndQualifiers, TypeIsType, TypeQualifiers, TypeVarBoundOrConstraints, + TypeVarInstance, TypeVarKind, TypeVarVariance, UnionBuilder, UnionType, binding_type, + todo_type, }; use crate::unpack::{Unpack, UnpackPosition}; use crate::util::subscript::{PyIndex, PySlice}; @@ -1208,7 +1208,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for place in overloaded_function_places { if let Place::Type(Type::FunctionLiteral(function), Boundness::Bound) = - place_from_bindings(self.db(), use_def.public_bindings(place)) + place_from_bindings(self.db(), use_def.end_of_scope_bindings(place)) { if function.file(self.db()) != self.file() { // If the function is not in this file, we don't need to check it. @@ -1579,7 +1579,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .place_table(FileScopeId::global()) .place_id_by_expr(&place.expr) { - Some(id) => global_use_def_map.public_declarations(id), + Some(id) => global_use_def_map.end_of_scope_declarations(id), // This case is a syntax error (load before global declaration) but ignore that here None => use_def.declarations_at_binding(binding), } @@ -1643,7 +1643,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(builder) = self.context.report_lint(&CONFLICTING_DECLARATIONS, node) { builder.into_diagnostic(format_args!( "Conflicting declared types for `{place}`: {}", - conflicting.display(db) + conflicting.iter().map(|ty| ty.display(db)).join(", ") )); } ty.inner_type() @@ -5663,7 +5663,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // If we're inferring types of deferred expressions, always treat them as public symbols if self.is_deferred() { let place = if let Some(place_id) = place_table.place_id_by_expr(expr) { - place_from_bindings(db, use_def.public_bindings(place_id)) + place_from_bindings(db, use_def.end_of_scope_bindings(place_id)) } else { assert!( self.deferred_state.in_string_annotation(), @@ -5818,9 +5818,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for enclosing_root_place in enclosing_place_table.root_place_exprs(expr) { if enclosing_root_place.is_bound() { - if let Place::Type(_, _) = - place(db, enclosing_scope_id, &enclosing_root_place.expr) - .place + if let Place::Type(_, _) = place( + db, + enclosing_scope_id, + &enclosing_root_place.expr, + ConsideredDefinitions::AllReachable, + ) + .place { return Place::Unbound.into(); } @@ -5846,7 +5850,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // runtime, it is the scope that creates the cell for our closure.) If the name // isn't bound in that scope, we should get an unbound name, not continue // falling back to other scopes / globals / builtins. - return place(db, enclosing_scope_id, expr).map_type(|ty| { + return place( + db, + enclosing_scope_id, + expr, + ConsideredDefinitions::AllReachable, + ) + .map_type(|ty| { self.narrow_place_with_applicable_constraints(expr, ty, &constraint_keys) }); } @@ -9884,7 +9894,7 @@ mod tests { assert_eq!(scope.name(db, &module), *expected_scope_name); } - symbol(db, scope, symbol_name).place + symbol(db, scope, symbol_name, ConsideredDefinitions::EndOfScope).place } #[track_caller] @@ -10129,7 +10139,7 @@ mod tests { fn first_public_binding<'db>(db: &'db TestDb, file: File, name: &str) -> Definition<'db> { let scope = global_scope(db, file); use_def_map(db, scope) - .public_bindings(place_table(db, scope).place_id_by_name(name).unwrap()) + .end_of_scope_bindings(place_table(db, scope).place_id_by_name(name).unwrap()) .find_map(|b| b.binding.definition()) .expect("no binding found") } diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index c091d8adcb..db8d344d50 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -5,12 +5,12 @@ use itertools::{Either, Itertools}; use ruff_python_ast::name::Name; use crate::{ + Db, FxOrderSet, place::{place_from_bindings, place_from_declarations}, semantic_index::{place_table, use_def_map}, types::{ ClassBase, ClassLiteral, KnownFunction, Type, TypeMapping, TypeQualifiers, TypeVarInstance, }, - {Db, FxOrderSet}, }; use super::TypeVarVariance; @@ -345,7 +345,7 @@ fn cached_protocol_interface<'db>( members.extend( use_def_map - .all_public_declarations() + .all_end_of_scope_declarations() .flat_map(|(place_id, declarations)| { place_from_declarations(db, declarations).map(|place| (place_id, place)) }) @@ -363,15 +363,13 @@ fn cached_protocol_interface<'db>( // members at runtime, and it's important that we accurately understand // type narrowing that uses `isinstance()` or `issubclass()` with // runtime-checkable protocols. - .chain( - use_def_map - .all_public_bindings() - .filter_map(|(place_id, bindings)| { - place_from_bindings(db, bindings) - .ignore_possibly_unbound() - .map(|ty| (place_id, ty, TypeQualifiers::default())) - }), - ) + .chain(use_def_map.all_end_of_scope_bindings().filter_map( + |(place_id, bindings)| { + place_from_bindings(db, bindings) + .ignore_possibly_unbound() + .map(|ty| (place_id, ty, TypeQualifiers::default())) + }, + )) .filter_map(|(place_id, member, qualifiers)| { Some(( place_table.place_expr(place_id).as_name()?, diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 2729e64a8a..473781f18d 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -1729,10 +1729,10 @@ mod tests { }; assert_eq!(a_name, "a"); assert_eq!(b_name, "b"); - // TODO resolution should not be deferred; we should see A not B + // TODO resolution should not be deferred; we should see A, not A | B assert_eq!( a_annotated_ty.unwrap().display(&db).to_string(), - "Unknown | B" + "Unknown | A | B" ); assert_eq!(b_annotated_ty.unwrap().display(&db).to_string(), "T"); } @@ -1777,8 +1777,8 @@ mod tests { }; assert_eq!(a_name, "a"); assert_eq!(b_name, "b"); - // Parameter resolution deferred; we should see B - assert_eq!(a_annotated_ty.unwrap().display(&db).to_string(), "B"); + // Parameter resolution deferred: + assert_eq!(a_annotated_ty.unwrap().display(&db).to_string(), "A | B"); assert_eq!(b_annotated_ty.unwrap().display(&db).to_string(), "T"); }