diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md b/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md index 91d55f7352..c4b475da1b 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md @@ -74,7 +74,9 @@ def _( def bar() -> None: return None -async def baz(): ... +async def baz() -> int: + return 42 + async def outer(): # avoid unrelated syntax errors on yield, yield from, and await def _( a: 1, # error: [invalid-type-form] "Int literals are not allowed in this context in a type expression" diff --git a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md index bb722884f3..48c31d7830 100644 --- a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md +++ b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md @@ -131,7 +131,8 @@ match obj: ```py class C: - def __await__(self): ... + def __await__(self): + yield # error: [invalid-syntax] "`return` statement outside of a function" return @@ -147,6 +148,8 @@ yield from [] await C() def f(): + # TODO: no error, C is awaitable + # error: [invalid-await] "`C` is not awaitable" # error: [invalid-syntax] "`await` outside of an asynchronous function" await C() ``` diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 1712d82b26..fd1626e3a1 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -462,7 +462,7 @@ reveal_type(D().h(1)) # revealed: Literal[2] | Unknown reveal_type(C().h(True)) # revealed: Literal[True] reveal_type(D().h(True)) # revealed: Literal[2] | Unknown reveal_type(C().i(1)) # revealed: list[Literal[1]] -reveal_type(D().i(1)) # revealed: list[Unknown] +reveal_type(D().i(1)) # revealed: list[@Todo(list literal element type)] class F: def f(self) -> Literal[1, 2]: diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 2e93c2a31e..2c80ad223e 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -639,7 +639,25 @@ impl<'db> ScopeInference<'db> { } } } - union.build() + + let module = parsed_module(db, self.scope.file(db)).load(db); + if self + .scope + .node(db) + .as_function(&module) + .is_some_and(|func| { + let index = semantic_index(db, self.scope.file(db)); + let is_generator = self.scope.file_scope_id(db).is_generator_function(index); + + func.is_async && !is_generator + }) + { + // TODO: yield/await type inference + KnownClass::CoroutineType + .to_specialized_instance(db, [Type::any(), Type::any(), union.build()]) + } else { + union.build() + } } } diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 160e2ef193..98596b2d84 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -762,12 +762,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { }; let rhs_ty = inference.expression_type(right); - let rhs_class = match rhs_ty { - Type::ClassLiteral(class) => class, - Type::GenericAlias(alias) => alias.origin(self.db), - _ => { - continue; - } + let Type::ClassLiteral(rhs_class) = rhs_ty else { + continue; }; // `else`-branch narrowing for `if type(x) is Y` can only be done