[red-knot] Fix call expression inference edge case for decorated functions (#13191)

This commit is contained in:
Alex Waygood 2024-09-01 16:19:40 +01:00 committed by GitHub
parent 5661353334
commit 2014cba87f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 54 additions and 15 deletions

View file

@ -326,7 +326,7 @@ impl<'db> Type<'db> {
#[must_use]
pub fn call(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Type::Function(function_type) => function_type.returns(db).or(Some(Type::Unknown)),
Type::Function(function_type) => Some(function_type.return_type(db)),
// TODO annotated return type on `__new__` or metaclass `__call__`
Type::Class(class) => Some(Type::Instance(*class)),
@ -374,14 +374,26 @@ impl<'db> FunctionType<'db> {
self.decorators(db).contains(&decorator)
}
/// annotated return type for this function, if any
pub fn returns(&self, db: &'db dyn Db) -> Option<Type<'db>> {
/// inferred return type for this function
pub fn return_type(&self, db: &'db dyn Db) -> Type<'db> {
let definition = self.definition(db);
let DefinitionKind::Function(function_stmt_node) = definition.node(db) else {
panic!("Function type definition must have `DefinitionKind::Function`")
};
function_stmt_node.returns.as_ref().map(|returns| {
// TODO if a function `bar` is decorated by `foo`,
// where `foo` is annotated as returning a type `X` that is a subtype of `Callable`,
// we need to infer the return type from `X`'s return annotation
// rather than from `bar`'s return annotation
// in order to determine the type that `bar` returns
if !function_stmt_node.decorator_list.is_empty() {
return Type::Unknown;
}
function_stmt_node
.returns
.as_ref()
.map(|returns| {
if function_stmt_node.is_async {
// TODO: generic `types.CoroutineType`!
Type::Unknown
@ -389,6 +401,7 @@ impl<'db> FunctionType<'db> {
definition_expression_ty(db, definition, returns.as_ref())
}
})
.unwrap_or(Type::Unknown)
}
}

View file

@ -2806,10 +2806,7 @@ mod tests {
panic!("example is not a function");
};
let returns = function
.returns(&db)
.expect("There is a return type on the function");
let returns = function.return_type(&db);
assert_eq!(returns.display(&db).to_string(), "int");
Ok(())
@ -2854,6 +2851,35 @@ mod tests {
Ok(())
}
#[test]
fn basic_decorated_call_expression() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
from typing import Callable
def foo() -> int:
return 42
def decorator(func) -> Callable[[], int]:
return foo
@decorator
def bar() -> str:
return 'bar'
x = bar()
",
)?;
// TODO: should be `int`!
assert_public_ty(&db, "src/a.py", "x", "Unknown");
Ok(())
}
#[test]
fn class_constructor_call_expression() -> anyhow::Result<()> {
let mut db = setup_db();