[red-knot] implement basic call expression inference (#13164)

## Summary

Adds basic support for inferring the type resulting from a call
expression. This only works for the *result* of call expressions; it
performs no inference on parameters. It also intentionally does nothing
with class instantiation, `__call__` implementors, or lambdas.

## Test Plan

Adds a test that it infers the right thing!

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Chris Krycho 2024-08-30 13:51:29 -06:00 committed by GitHub
parent a73bebcf15
commit 28ab5f4065
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 93 additions and 18 deletions

View file

@ -320,6 +320,33 @@ impl<'db> Type<'db> {
} }
} }
/// Return the type resulting from calling an object of this type.
///
/// Returns `None` if `self` is not a callable type.
#[must_use]
pub fn call(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Type::Function(function_type) => function_type.returns(db).or(Some(Type::Unknown)),
// TODO: handle class constructors
Type::Class(_class_ty) => Some(Type::Unknown),
// TODO: handle classes which implement the Callable protocol
Type::Instance(_instance_ty) => Some(Type::Unknown),
// `Any` is callable, and its return type is also `Any`.
Type::Any => Some(Type::Any),
Type::Unknown => Some(Type::Unknown),
// TODO: union and intersection types, if they reduce to `Callable`
Type::Union(_) => Some(Type::Unknown),
Type::Intersection(_) => Some(Type::Unknown),
_ => None,
}
}
#[must_use] #[must_use]
pub fn instance(&self) -> Type<'db> { pub fn instance(&self) -> Type<'db> {
match self { match self {
@ -550,4 +577,25 @@ mod tests {
let b_file_diagnostics = super::check_types(&db, b_file); let b_file_diagnostics = super::check_types(&db, b_file);
assert_eq!(&*b_file_diagnostics, &[]); assert_eq!(&*b_file_diagnostics, &[]);
} }
#[test]
fn invalid_callable() {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
nonsense = 123
x = nonsense()
",
)
.unwrap();
let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'Literal[123]' is not callable"],
);
}
} }

View file

@ -868,19 +868,16 @@ impl<'db> TypeInferenceBuilder<'db> {
value, value,
} = assignment; } = assignment;
// TODO remove once we infer definitions in unpacking assignment, since that infers the RHS
// too, and uses the `infer_expression_types` query to do it
self.infer_expression(value);
for target in targets { for target in targets {
match target { if let ast::Expr::Name(name) = target {
ast::Expr::Name(name) => { self.infer_definition(name);
self.infer_definition(name); } else {
} // TODO infer definitions in unpacking assignment. When we do, this duplication of
_ => { // the "get `Expression`, call `infer_expression_types` on it, `self.extend`" dance
// TODO infer definitions in unpacking assignment // will be removed; it'll all happen in `infer_assignment_definition` instead.
self.infer_expression(target); let expression = self.index.expression(value.as_ref());
} self.extend(infer_expression_types(self.db, expression));
self.infer_expression(target);
} }
} }
} }
@ -1363,7 +1360,8 @@ impl<'db> TypeInferenceBuilder<'db> {
}; };
let expr_id = expression.scoped_ast_id(self.db, self.scope); let expr_id = expression.scoped_ast_id(self.db, self.scope);
self.types.expressions.insert(expr_id, ty); let previous = self.types.expressions.insert(expr_id, ty);
assert!(previous.is_none());
ty ty
} }
@ -1746,10 +1744,18 @@ impl<'db> TypeInferenceBuilder<'db> {
} = call_expression; } = call_expression;
self.infer_arguments(arguments); self.infer_arguments(arguments);
self.infer_expression(func); let function_type = self.infer_expression(func);
function_type.call(self.db).unwrap_or_else(|| {
// TODO resolve to return type of `func`, if its a callable type self.add_diagnostic(
Type::Unknown func.as_ref().into(),
"call-non-callable",
format_args!(
"Object of type '{}' is not callable",
function_type.display(self.db)
),
);
Type::Unknown
})
} }
fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> { fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> {
@ -2247,7 +2253,8 @@ impl<'db> TypeInferenceBuilder<'db> {
}; };
let expr_id = expression.scoped_ast_id(self.db, self.scope); let expr_id = expression.scoped_ast_id(self.db, self.scope);
self.types.expressions.insert(expr_id, ty); let previous = self.types.expressions.insert(expr_id, ty);
assert!(previous.is_none());
ty ty
} }
@ -2808,6 +2815,25 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn basic_call_expression() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
def get_int() -> int:
return 42
x = get_int()
",
)?;
assert_public_ty(&db, "src/a.py", "x", "int");
Ok(())
}
#[test] #[test]
fn resolve_union() -> anyhow::Result<()> { fn resolve_union() -> anyhow::Result<()> {
let mut db = setup_db(); let mut db = setup_db();

View file

@ -24,6 +24,7 @@ const TOMLLIB_312_URL: &str = "https://raw.githubusercontent.com/python/cpython/
// The "unresolved import" is because we don't understand `*` imports yet. // The "unresolved import" is because we don't understand `*` imports yet.
static EXPECTED_DIAGNOSTICS: &[&str] = &[ static EXPECTED_DIAGNOSTICS: &[&str] = &[
"/src/tomllib/_parser.py:7:29: Module 'collections.abc' has no member 'Iterable'", "/src/tomllib/_parser.py:7:29: Module 'collections.abc' has no member 'Iterable'",
"/src/tomllib/_parser.py:686:23: Object of type 'Unbound' is not callable",
"Line 69 is too long (89 characters)", "Line 69 is too long (89 characters)",
"Use double quotes for strings", "Use double quotes for strings",
"Use double quotes for strings", "Use double quotes for strings",