Support classes that implement __call__ (#13580)

## Summary

This looked straightforward and removes some TODOs.
This commit is contained in:
Charlie Marsh 2024-10-01 13:15:46 -04:00 committed by GitHub
parent 043fba7a57
commit edba60106b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 45 additions and 2 deletions

View file

@ -609,8 +609,20 @@ impl<'db> Type<'db> {
})
}
// TODO: handle classes which implement the `__call__` protocol
Type::Instance(_instance_ty) => CallOutcome::callable(Type::Todo),
Type::Instance(class) => {
// Since `__call__` is a dunder, we need to access it as an attribute on the class
// rather than the instance (matching runtime semantics).
let meta_ty = Type::Class(class);
let dunder_call_method = meta_ty.member(db, "__call__");
if dunder_call_method.is_unbound() {
CallOutcome::not_callable(self)
} else {
let args = std::iter::once(self)
.chain(arg_types.iter().copied())
.collect::<Vec<_>>();
dunder_call_method.call(db, &args)
}
}
// `Any` is callable, and its return type is also `Any`.
Type::Any => CallOutcome::callable(Type::Any),

View file

@ -6723,6 +6723,37 @@ mod tests {
Ok(())
}
#[test]
fn dunder_call() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
class Multiplier:
def __init__(self, factor: float):
self.factor = factor
def __call__(self, number: float) -> float:
return number * self.factor
a = Multiplier(2.0)(3.0)
class Unit:
...
b = Unit()(3.0)
",
)?;
assert_public_ty(&db, "/src/a.py", "a", "float");
assert_public_ty(&db, "/src/a.py", "b", "Unknown");
assert_file_diagnostics(&db, "src/a.py", &["Object of type 'Unit' is not callable."]);
Ok(())
}
#[test]
fn boolean_or_expression() -> anyhow::Result<()> {
let mut db = setup_db();