[red-knot] Fix bug where union of two iterable types was not recognised as iterable (#13992)

This commit is contained in:
Alex Waygood 2024-10-30 11:54:16 +00:00 committed by GitHub
parent 1607d88c22
commit 42c70697d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 81 additions and 5 deletions

View file

@ -144,3 +144,82 @@ class NotIterable:
for x in NotIterable(): # error: "Object of type `NotIterable` is not iterable"
pass
```
## Union type as iterable
```py
class TestIter:
def __next__(self) -> int:
return 42
class Test:
def __iter__(self) -> TestIter:
return TestIter()
class Test2:
def __iter__(self) -> TestIter:
return TestIter()
def bool_instance() -> bool:
return True
flag = bool_instance()
for x in Test() if flag else Test2():
reveal_type(x) # revealed: int
```
## Union type as iterator
```py
class TestIter:
def __next__(self) -> int:
return 42
class TestIter2:
def __next__(self) -> int:
return 42
class Test:
def __iter__(self) -> TestIter | TestIter2:
return TestIter()
for x in Test():
reveal_type(x) # revealed: int
```
## Union type as iterable and union type as iterator
```py
class TestIter:
def __next__(self) -> int | Exception:
return 42
class TestIter2:
def __next__(self) -> str | tuple[int, int]:
return "42"
class TestIter3:
def __next__(self) -> bytes:
return b"42"
class TestIter4:
def __next__(self) -> memoryview:
return memoryview(b"42")
class Test:
def __iter__(self) -> TestIter | TestIter2:
return TestIter()
class Test2:
def __iter__(self) -> TestIter3 | TestIter4:
return TestIter3()
def bool_instance() -> bool:
return True
flag = bool_instance()
for x in Test() if flag else Test2():
reveal_type(x) # revealed: int | Exception | str | tuple[int, int] | bytes | memoryview
```

View file

@ -1104,10 +1104,7 @@ impl<'db> Type<'db> {
let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
if !dunder_iter_method.is_unbound() {
let CallOutcome::Callable {
return_ty: iterator_ty,
} = dunder_iter_method.call(db, &[self])
else {
let Some(iterator_ty) = dunder_iter_method.call(db, &[self]).return_ty(db) else {
return IterationOutcome::NotIterable {
not_iterable_ty: self,
};
@ -1115,7 +1112,7 @@ impl<'db> Type<'db> {
let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method
.call(db, &[self])
.call(db, &[iterator_ty])
.return_ty(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {