From edba60106bbbb80c81c3b4540a4247368441b4dc Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Tue, 1 Oct 2024 13:15:46 -0400 Subject: [PATCH] Support classes that implement `__call__` (#13580) ## Summary This looked straightforward and removes some TODOs. --- crates/red_knot_python_semantic/src/types.rs | 16 ++++++++-- .../src/types/infer.rs | 31 +++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 6b302fc4e1..50bab3d555 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -609,8 +609,20 @@ impl<'db> Type<'db> { }) } - // TODO: handle classes which implement the `__call__` protocol - Type::Instance(_instance_ty) => CallOutcome::callable(Type::Todo), + Type::Instance(class) => { + // Since `__call__` is a dunder, we need to access it as an attribute on the class + // rather than the instance (matching runtime semantics). + let meta_ty = Type::Class(class); + let dunder_call_method = meta_ty.member(db, "__call__"); + if dunder_call_method.is_unbound() { + CallOutcome::not_callable(self) + } else { + let args = std::iter::once(self) + .chain(arg_types.iter().copied()) + .collect::>(); + dunder_call_method.call(db, &args) + } + } // `Any` is callable, and its return type is also `Any`. Type::Any => CallOutcome::callable(Type::Any), diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 72fe24064c..f1444663cd 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -6723,6 +6723,37 @@ mod tests { Ok(()) } + #[test] + fn dunder_call() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + class Multiplier: + def __init__(self, factor: float): + self.factor = factor + + def __call__(self, number: float) -> float: + return number * self.factor + + a = Multiplier(2.0)(3.0) + + class Unit: + ... + + b = Unit()(3.0) + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "float"); + assert_public_ty(&db, "/src/a.py", "b", "Unknown"); + + assert_file_diagnostics(&db, "src/a.py", &["Object of type 'Unit' is not callable."]); + + Ok(()) + } + #[test] fn boolean_or_expression() -> anyhow::Result<()> { let mut db = setup_db();