[ty] Support narrowing on isinstance()/issubclass() if the second argument is a dynamic, intersection, union or typevar type (#18900)

This commit is contained in:
Alex Waygood 2025-06-24 11:55:26 +01:00 committed by GitHub
parent fd2cc37f90
commit 27eee5a1a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 246 additions and 9 deletions

View file

@ -211,3 +211,100 @@ def f(
else:
reveal_type(d) # revealed: P & ~AlwaysFalsy
```
## Narrowing if an object of type `Any` or `Unknown` is used as the second argument
In order to preserve the gradual guarantee, we intersect with the type of the second argument if the
type of the second argument is a dynamic type:
```py
from typing import Any
from something_unresolvable import SomethingUnknown # error: [unresolved-import]
class Foo: ...
def f(a: Foo, b: Any):
if isinstance(a, SomethingUnknown):
reveal_type(a) # revealed: Foo & Unknown
if isinstance(a, b):
reveal_type(a) # revealed: Foo & Any
```
## Narrowing if an object with an intersection/union/TypeVar type is used as the second argument
If an intersection with only positive members is used as the second argument, and all positive
members of the intersection are valid arguments for the second argument to `isinstance()`, we
intersect with each positive member of the intersection:
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Any
from ty_extensions import Intersection
class Foo: ...
class Bar:
attribute: int
class Baz:
attribute: str
def f(x: Foo, y: Intersection[type[Bar], type[Baz]], z: type[Any]):
if isinstance(x, y):
reveal_type(x) # revealed: Foo & Bar & Baz
if isinstance(x, z):
reveal_type(x) # revealed: Foo & Any
```
The same if a union type is used:
```py
def g(x: Foo, y: type[Bar | Baz]):
if isinstance(x, y):
reveal_type(x) # revealed: (Foo & Bar) | (Foo & Baz)
```
And even if a `TypeVar` is used, providing it has valid upper bounds/constraints:
```py
from typing import TypeVar
T = TypeVar("T", bound=type[Bar])
def h_old_syntax(x: Foo, y: T) -> T:
if isinstance(x, y):
reveal_type(x) # revealed: Foo & Bar
reveal_type(x.attribute) # revealed: int
return y
def h[U: type[Bar | Baz]](x: Foo, y: U) -> U:
if isinstance(x, y):
reveal_type(x) # revealed: (Foo & Bar) | (Foo & Baz)
reveal_type(x.attribute) # revealed: int | str
return y
```
Or even a tuple of tuple of typevars that have intersection bounds...
```py
from ty_extensions import Intersection
class Spam: ...
class Eggs: ...
class Ham: ...
class Mushrooms: ...
def i[T: Intersection[type[Bar], type[Baz | Spam]], U: (type[Eggs], type[Ham])](x: Foo, y: T, z: U) -> tuple[T, U]:
if isinstance(x, (y, (z, Mushrooms))):
reveal_type(x) # revealed: (Foo & Bar & Baz) | (Foo & Bar & Spam) | (Foo & Eggs) | (Foo & Ham) | (Foo & Mushrooms)
return (y, z)
```

View file

@ -278,3 +278,82 @@ def _(x: type[UsesMeta1], y: type[UsesMeta2]):
else:
reveal_type(y) # revealed: type[UsesMeta2]
```
## Narrowing if an object with an intersection/union/TypeVar type is used as the second argument
If an intersection with only positive members is used as the second argument, and all positive
members of the intersection are valid arguments for the second argument to `isinstance()`, we
intersect with each positive member of the intersection:
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Any, ClassVar
from ty_extensions import Intersection
class Foo: ...
class Bar:
attribute: ClassVar[int]
class Baz:
attribute: ClassVar[str]
def f(x: type[Foo], y: Intersection[type[Bar], type[Baz]], z: type[Any]):
if issubclass(x, y):
reveal_type(x) # revealed: type[Foo] & type[Bar] & type[Baz]
if issubclass(x, z):
reveal_type(x) # revealed: type[Foo] & Any
```
The same if a union type is used:
```py
def g(x: type[Foo], y: type[Bar | Baz]):
if issubclass(x, y):
reveal_type(x) # revealed: (type[Foo] & type[Bar]) | (type[Foo] & type[Baz])
```
And even if a `TypeVar` is used, providing it has valid upper bounds/constraints:
```py
from typing import TypeVar
T = TypeVar("T", bound=type[Bar])
def h_old_syntax(x: type[Foo], y: T) -> T:
if issubclass(x, y):
reveal_type(x) # revealed: type[Foo] & type[Bar]
reveal_type(x.attribute) # revealed: int
return y
def h[U: type[Bar | Baz]](x: type[Foo], y: U) -> U:
if issubclass(x, y):
reveal_type(x) # revealed: (type[Foo] & type[Bar]) | (type[Foo] & type[Baz])
reveal_type(x.attribute) # revealed: int | str
return y
```
Or even a tuple of tuple of typevars that have intersection bounds...
```py
from ty_extensions import Intersection
class Spam: ...
class Eggs: ...
class Ham: ...
class Mushrooms: ...
def i[T: Intersection[type[Bar], type[Baz | Spam]], U: (type[Eggs], type[Ham])](x: type[Foo], y: T, z: U) -> tuple[T, U]:
if issubclass(x, (y, (z, Mushrooms))):
# revealed: (type[Foo] & type[Bar] & type[Baz]) | (type[Foo] & type[Bar] & type[Spam]) | (type[Foo] & type[Eggs]) | (type[Foo] & type[Ham]) | (type[Foo] & type[Mushrooms])
reveal_type(x)
return (y, z)
```

View file

@ -9,8 +9,8 @@ use crate::semantic_index::predicate::{
use crate::types::function::KnownFunction;
use crate::types::infer::infer_same_file_expression_type;
use crate::types::{
IntersectionBuilder, KnownClass, SubclassOfType, Truthiness, Type, UnionBuilder,
infer_expression_types,
ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SubclassOfInner, SubclassOfType,
Truthiness, Type, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types,
};
use ruff_db::parsed::{ParsedModuleRef, parsed_module};
@ -167,9 +167,13 @@ impl ClassInfoConstraintFunction {
/// The `classinfo` argument can be a class literal, a tuple of (tuples of) class literals. PEP 604
/// union types are not yet supported. Returns `None` if the `classinfo` argument has a wrong type.
fn generate_constraint<'db>(self, db: &'db dyn Db, classinfo: Type<'db>) -> Option<Type<'db>> {
let constraint_fn = |class| match self {
ClassInfoConstraintFunction::IsInstance => Type::instance(db, class),
ClassInfoConstraintFunction::IsSubclass => SubclassOfType::from(db, class),
let constraint_fn = |class: ClassLiteral<'db>| match self {
ClassInfoConstraintFunction::IsInstance => {
Type::instance(db, class.default_specialization(db))
}
ClassInfoConstraintFunction::IsSubclass => {
SubclassOfType::from(db, class.default_specialization(db))
}
};
match classinfo {
@ -186,13 +190,70 @@ impl ClassInfoConstraintFunction {
if class_literal.is_known(db, KnownClass::Any) {
None
} else {
Some(constraint_fn(class_literal.default_specialization(db)))
Some(constraint_fn(class_literal))
}
}
Type::SubclassOf(subclass_of_ty) => {
subclass_of_ty.subclass_of().into_class().map(constraint_fn)
Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() {
SubclassOfInner::Class(ClassType::NonGeneric(class)) => Some(constraint_fn(class)),
// It's not valid to use a generic alias as the second argument to `isinstance()` or `issubclass()`,
// e.g. `isinstance(x, list[int])` fails at runtime.
SubclassOfInner::Class(ClassType::Generic(_)) => None,
SubclassOfInner::Dynamic(dynamic) => Some(Type::Dynamic(dynamic)),
},
Type::Dynamic(_) => Some(classinfo),
Type::Intersection(intersection) => {
if intersection.negative(db).is_empty() {
let mut builder = IntersectionBuilder::new(db);
for element in intersection.positive(db) {
builder = builder.add_positive(self.generate_constraint(db, *element)?);
}
Some(builder.build())
} else {
// TODO: can we do better here?
None
}
}
_ => None,
Type::Union(union) => {
let mut builder = UnionBuilder::new(db);
for element in union.elements(db) {
builder = builder.add(self.generate_constraint(db, *element)?);
}
Some(builder.build())
}
Type::TypeVar(type_var) => match type_var.bound_or_constraints(db)? {
TypeVarBoundOrConstraints::UpperBound(bound) => self.generate_constraint(db, bound),
TypeVarBoundOrConstraints::Constraints(constraints) => {
self.generate_constraint(db, Type::Union(constraints))
}
},
// It's not valid to use a generic alias as the second argument to `isinstance()` or `issubclass()`,
// e.g. `isinstance(x, list[int])` fails at runtime.
Type::GenericAlias(_) => None,
Type::AlwaysFalsy
| Type::AlwaysTruthy
| Type::BooleanLiteral(_)
| Type::BoundMethod(_)
| Type::BoundSuper(_)
| Type::BytesLiteral(_)
| Type::Callable(_)
| Type::DataclassDecorator(_)
| Type::Never
| Type::MethodWrapper(_)
| Type::ModuleLiteral(_)
| Type::FunctionLiteral(_)
| Type::ProtocolInstance(_)
| Type::PropertyInstance(_)
| Type::SpecialForm(_)
| Type::NominalInstance(_)
| Type::LiteralString
| Type::StringLiteral(_)
| Type::IntLiteral(_)
| Type::KnownInstance(_)
| Type::TypeIs(_)
| Type::WrapperDescriptor(_)
| Type::DataclassTransformer(_) => None,
}
}
}