mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 21:35:58 +00:00
[red-knot] Fix call expression inference edge case for decorated functions (#13191)
This commit is contained in:
parent
5661353334
commit
2014cba87f
2 changed files with 54 additions and 15 deletions
|
@ -326,7 +326,7 @@ impl<'db> Type<'db> {
|
|||
#[must_use]
|
||||
pub fn call(&self, db: &'db dyn Db) -> Option<Type<'db>> {
|
||||
match self {
|
||||
Type::Function(function_type) => function_type.returns(db).or(Some(Type::Unknown)),
|
||||
Type::Function(function_type) => Some(function_type.return_type(db)),
|
||||
|
||||
// TODO annotated return type on `__new__` or metaclass `__call__`
|
||||
Type::Class(class) => Some(Type::Instance(*class)),
|
||||
|
@ -374,21 +374,34 @@ impl<'db> FunctionType<'db> {
|
|||
self.decorators(db).contains(&decorator)
|
||||
}
|
||||
|
||||
/// annotated return type for this function, if any
|
||||
pub fn returns(&self, db: &'db dyn Db) -> Option<Type<'db>> {
|
||||
/// inferred return type for this function
|
||||
pub fn return_type(&self, db: &'db dyn Db) -> Type<'db> {
|
||||
let definition = self.definition(db);
|
||||
let DefinitionKind::Function(function_stmt_node) = definition.node(db) else {
|
||||
panic!("Function type definition must have `DefinitionKind::Function`")
|
||||
};
|
||||
|
||||
function_stmt_node.returns.as_ref().map(|returns| {
|
||||
if function_stmt_node.is_async {
|
||||
// TODO: generic `types.CoroutineType`!
|
||||
Type::Unknown
|
||||
} else {
|
||||
definition_expression_ty(db, definition, returns.as_ref())
|
||||
}
|
||||
})
|
||||
// TODO if a function `bar` is decorated by `foo`,
|
||||
// where `foo` is annotated as returning a type `X` that is a subtype of `Callable`,
|
||||
// we need to infer the return type from `X`'s return annotation
|
||||
// rather than from `bar`'s return annotation
|
||||
// in order to determine the type that `bar` returns
|
||||
if !function_stmt_node.decorator_list.is_empty() {
|
||||
return Type::Unknown;
|
||||
}
|
||||
|
||||
function_stmt_node
|
||||
.returns
|
||||
.as_ref()
|
||||
.map(|returns| {
|
||||
if function_stmt_node.is_async {
|
||||
// TODO: generic `types.CoroutineType`!
|
||||
Type::Unknown
|
||||
} else {
|
||||
definition_expression_ty(db, definition, returns.as_ref())
|
||||
}
|
||||
})
|
||||
.unwrap_or(Type::Unknown)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2806,10 +2806,7 @@ mod tests {
|
|||
panic!("example is not a function");
|
||||
};
|
||||
|
||||
let returns = function
|
||||
.returns(&db)
|
||||
.expect("There is a return type on the function");
|
||||
|
||||
let returns = function.return_type(&db);
|
||||
assert_eq!(returns.display(&db).to_string(), "int");
|
||||
|
||||
Ok(())
|
||||
|
@ -2854,6 +2851,35 @@ mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn basic_decorated_call_expression() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
"
|
||||
from typing import Callable
|
||||
|
||||
def foo() -> int:
|
||||
return 42
|
||||
|
||||
def decorator(func) -> Callable[[], int]:
|
||||
return foo
|
||||
|
||||
@decorator
|
||||
def bar() -> str:
|
||||
return 'bar'
|
||||
|
||||
x = bar()
|
||||
",
|
||||
)?;
|
||||
|
||||
// TODO: should be `int`!
|
||||
assert_public_ty(&db, "src/a.py", "x", "Unknown");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn class_constructor_call_expression() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue