diff --git a/crates/red_knot_python_semantic/resources/mdtest/loops/for_loop.md b/crates/red_knot_python_semantic/resources/mdtest/loops/for_loop.md index d2e30b0f52..58dfe0ad51 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/loops/for_loop.md +++ b/crates/red_knot_python_semantic/resources/mdtest/loops/for_loop.md @@ -223,3 +223,61 @@ flag = bool_instance() for x in Test() if flag else Test2(): reveal_type(x) # revealed: int | Exception | str | tuple[int, int] | bytes | memoryview ``` + +## Union type as iterable where one union element has no `__iter__` method + +```py +class TestIter: + def __next__(self) -> int: + return 42 + +class Test: + def __iter__(self) -> TestIter: + return TestIter() + +def coinflip() -> bool: + return True + +# TODO: we should emit a diagnostic here (it might not be iterable) +for x in Test() if coinflip() else 42: + reveal_type(x) # revealed: int | Unknown +``` + +## Union type as iterable where one union element has invalid `__iter__` method + +```py +class TestIter: + def __next__(self) -> int: + return 42 + +class Test: + def __iter__(self) -> TestIter: + return TestIter() + +class Test2: + def __iter__(self) -> int: + return 42 + +def coinflip() -> bool: + return True + +# TODO: we should emit a diagnostic here (it might not be iterable) +for x in Test() if coinflip() else Test2(): + reveal_type(x) # revealed: int | Unknown +``` + +## Union type as iterator where one union element has no `__next__` method + +```py +class TestIter: + def __next__(self) -> int: + return 42 + +class Test: + def __iter__(self) -> TestIter | int: + return TestIter() + +# TODO: we should emit a diagnostic here (it might not be iterable) +for x in Test(): + reveal_type(x) # revealed: int | Unknown +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 459b17b036..e3b9e78182 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1090,7 +1090,7 @@ impl<'db> Type<'db> { }; } - if let Type::Unknown | Type::Any = self { + if matches!(self, Type::Unknown | Type::Any | Type::Todo) { // Explicit handling of `Unknown` and `Any` necessary until `type[Unknown]` and // `type[Any]` are not defined as `Todo` anymore. return IterationOutcome::Iterable { element_ty: self }; @@ -1185,9 +1185,9 @@ impl<'db> Type<'db> { // TODO can we do better here? `type[LiteralString]`? Type::StringLiteral(_) | Type::LiteralString => KnownClass::Str.to_class(db), // TODO: `type[Any]`? - Type::Any => Type::Todo, + Type::Any => Type::Any, // TODO: `type[Unknown]`? - Type::Unknown => Type::Todo, + Type::Unknown => Type::Unknown, // TODO intersections Type::Intersection(_) => Type::Todo, Type::Todo => Type::Todo,