mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-30 13:51:37 +00:00
[red-knot] Fix bug where union of two iterable types was not recognised as iterable (#13992)
This commit is contained in:
parent
1607d88c22
commit
42c70697d8
2 changed files with 81 additions and 5 deletions
|
@ -144,3 +144,82 @@ class NotIterable:
|
||||||
for x in NotIterable(): # error: "Object of type `NotIterable` is not iterable"
|
for x in NotIterable(): # error: "Object of type `NotIterable` is not iterable"
|
||||||
pass
|
pass
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Union type as iterable
|
||||||
|
|
||||||
|
```py
|
||||||
|
class TestIter:
|
||||||
|
def __next__(self) -> int:
|
||||||
|
return 42
|
||||||
|
|
||||||
|
class Test:
|
||||||
|
def __iter__(self) -> TestIter:
|
||||||
|
return TestIter()
|
||||||
|
|
||||||
|
class Test2:
|
||||||
|
def __iter__(self) -> TestIter:
|
||||||
|
return TestIter()
|
||||||
|
|
||||||
|
def bool_instance() -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
flag = bool_instance()
|
||||||
|
|
||||||
|
for x in Test() if flag else Test2():
|
||||||
|
reveal_type(x) # revealed: int
|
||||||
|
```
|
||||||
|
|
||||||
|
## Union type as iterator
|
||||||
|
|
||||||
|
```py
|
||||||
|
class TestIter:
|
||||||
|
def __next__(self) -> int:
|
||||||
|
return 42
|
||||||
|
|
||||||
|
class TestIter2:
|
||||||
|
def __next__(self) -> int:
|
||||||
|
return 42
|
||||||
|
|
||||||
|
class Test:
|
||||||
|
def __iter__(self) -> TestIter | TestIter2:
|
||||||
|
return TestIter()
|
||||||
|
|
||||||
|
for x in Test():
|
||||||
|
reveal_type(x) # revealed: int
|
||||||
|
```
|
||||||
|
|
||||||
|
## Union type as iterable and union type as iterator
|
||||||
|
|
||||||
|
```py
|
||||||
|
class TestIter:
|
||||||
|
def __next__(self) -> int | Exception:
|
||||||
|
return 42
|
||||||
|
|
||||||
|
class TestIter2:
|
||||||
|
def __next__(self) -> str | tuple[int, int]:
|
||||||
|
return "42"
|
||||||
|
|
||||||
|
class TestIter3:
|
||||||
|
def __next__(self) -> bytes:
|
||||||
|
return b"42"
|
||||||
|
|
||||||
|
class TestIter4:
|
||||||
|
def __next__(self) -> memoryview:
|
||||||
|
return memoryview(b"42")
|
||||||
|
|
||||||
|
class Test:
|
||||||
|
def __iter__(self) -> TestIter | TestIter2:
|
||||||
|
return TestIter()
|
||||||
|
|
||||||
|
class Test2:
|
||||||
|
def __iter__(self) -> TestIter3 | TestIter4:
|
||||||
|
return TestIter3()
|
||||||
|
|
||||||
|
def bool_instance() -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
flag = bool_instance()
|
||||||
|
|
||||||
|
for x in Test() if flag else Test2():
|
||||||
|
reveal_type(x) # revealed: int | Exception | str | tuple[int, int] | bytes | memoryview
|
||||||
|
```
|
||||||
|
|
|
@ -1104,10 +1104,7 @@ impl<'db> Type<'db> {
|
||||||
|
|
||||||
let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
|
let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
|
||||||
if !dunder_iter_method.is_unbound() {
|
if !dunder_iter_method.is_unbound() {
|
||||||
let CallOutcome::Callable {
|
let Some(iterator_ty) = dunder_iter_method.call(db, &[self]).return_ty(db) else {
|
||||||
return_ty: iterator_ty,
|
|
||||||
} = dunder_iter_method.call(db, &[self])
|
|
||||||
else {
|
|
||||||
return IterationOutcome::NotIterable {
|
return IterationOutcome::NotIterable {
|
||||||
not_iterable_ty: self,
|
not_iterable_ty: self,
|
||||||
};
|
};
|
||||||
|
@ -1115,7 +1112,7 @@ impl<'db> Type<'db> {
|
||||||
|
|
||||||
let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
|
let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
|
||||||
return dunder_next_method
|
return dunder_next_method
|
||||||
.call(db, &[self])
|
.call(db, &[iterator_ty])
|
||||||
.return_ty(db)
|
.return_ty(db)
|
||||||
.map(|element_ty| IterationOutcome::Iterable { element_ty })
|
.map(|element_ty| IterationOutcome::Iterable { element_ty })
|
||||||
.unwrap_or(IterationOutcome::NotIterable {
|
.unwrap_or(IterationOutcome::NotIterable {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue