mirror of
https://github.com/astral-sh/ruff.git
synced 2025-10-02 06:42:02 +00:00
[red-knot] implement basic call expression inference (#13164)
## Summary Adds basic support for inferring the type resulting from a call expression. This only works for the *result* of call expressions; it performs no inference on parameters. It also intentionally does nothing with class instantiation, `__call__` implementors, or lambdas. ## Test Plan Adds a test that it infers the right thing! --------- Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
parent
a73bebcf15
commit
28ab5f4065
3 changed files with 93 additions and 18 deletions
|
@ -320,6 +320,33 @@ impl<'db> Type<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the type resulting from calling an object of this type.
|
||||||
|
///
|
||||||
|
/// Returns `None` if `self` is not a callable type.
|
||||||
|
#[must_use]
|
||||||
|
pub fn call(&self, db: &'db dyn Db) -> Option<Type<'db>> {
|
||||||
|
match self {
|
||||||
|
Type::Function(function_type) => function_type.returns(db).or(Some(Type::Unknown)),
|
||||||
|
|
||||||
|
// TODO: handle class constructors
|
||||||
|
Type::Class(_class_ty) => Some(Type::Unknown),
|
||||||
|
|
||||||
|
// TODO: handle classes which implement the Callable protocol
|
||||||
|
Type::Instance(_instance_ty) => Some(Type::Unknown),
|
||||||
|
|
||||||
|
// `Any` is callable, and its return type is also `Any`.
|
||||||
|
Type::Any => Some(Type::Any),
|
||||||
|
|
||||||
|
Type::Unknown => Some(Type::Unknown),
|
||||||
|
|
||||||
|
// TODO: union and intersection types, if they reduce to `Callable`
|
||||||
|
Type::Union(_) => Some(Type::Unknown),
|
||||||
|
Type::Intersection(_) => Some(Type::Unknown),
|
||||||
|
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn instance(&self) -> Type<'db> {
|
pub fn instance(&self) -> Type<'db> {
|
||||||
match self {
|
match self {
|
||||||
|
@ -550,4 +577,25 @@ mod tests {
|
||||||
let b_file_diagnostics = super::check_types(&db, b_file);
|
let b_file_diagnostics = super::check_types(&db, b_file);
|
||||||
assert_eq!(&*b_file_diagnostics, &[]);
|
assert_eq!(&*b_file_diagnostics, &[]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_callable() {
|
||||||
|
let mut db = setup_db();
|
||||||
|
|
||||||
|
db.write_dedented(
|
||||||
|
"src/a.py",
|
||||||
|
"
|
||||||
|
nonsense = 123
|
||||||
|
x = nonsense()
|
||||||
|
",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
|
||||||
|
let a_file_diagnostics = super::check_types(&db, a_file);
|
||||||
|
assert_diagnostic_messages(
|
||||||
|
&a_file_diagnostics,
|
||||||
|
&["Object of type 'Literal[123]' is not callable"],
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -868,19 +868,16 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
value,
|
value,
|
||||||
} = assignment;
|
} = assignment;
|
||||||
|
|
||||||
// TODO remove once we infer definitions in unpacking assignment, since that infers the RHS
|
|
||||||
// too, and uses the `infer_expression_types` query to do it
|
|
||||||
self.infer_expression(value);
|
|
||||||
|
|
||||||
for target in targets {
|
for target in targets {
|
||||||
match target {
|
if let ast::Expr::Name(name) = target {
|
||||||
ast::Expr::Name(name) => {
|
self.infer_definition(name);
|
||||||
self.infer_definition(name);
|
} else {
|
||||||
}
|
// TODO infer definitions in unpacking assignment. When we do, this duplication of
|
||||||
_ => {
|
// the "get `Expression`, call `infer_expression_types` on it, `self.extend`" dance
|
||||||
// TODO infer definitions in unpacking assignment
|
// will be removed; it'll all happen in `infer_assignment_definition` instead.
|
||||||
self.infer_expression(target);
|
let expression = self.index.expression(value.as_ref());
|
||||||
}
|
self.extend(infer_expression_types(self.db, expression));
|
||||||
|
self.infer_expression(target);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1363,7 +1360,8 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
};
|
};
|
||||||
|
|
||||||
let expr_id = expression.scoped_ast_id(self.db, self.scope);
|
let expr_id = expression.scoped_ast_id(self.db, self.scope);
|
||||||
self.types.expressions.insert(expr_id, ty);
|
let previous = self.types.expressions.insert(expr_id, ty);
|
||||||
|
assert!(previous.is_none());
|
||||||
|
|
||||||
ty
|
ty
|
||||||
}
|
}
|
||||||
|
@ -1746,10 +1744,18 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
} = call_expression;
|
} = call_expression;
|
||||||
|
|
||||||
self.infer_arguments(arguments);
|
self.infer_arguments(arguments);
|
||||||
self.infer_expression(func);
|
let function_type = self.infer_expression(func);
|
||||||
|
function_type.call(self.db).unwrap_or_else(|| {
|
||||||
// TODO resolve to return type of `func`, if its a callable type
|
self.add_diagnostic(
|
||||||
Type::Unknown
|
func.as_ref().into(),
|
||||||
|
"call-non-callable",
|
||||||
|
format_args!(
|
||||||
|
"Object of type '{}' is not callable",
|
||||||
|
function_type.display(self.db)
|
||||||
|
),
|
||||||
|
);
|
||||||
|
Type::Unknown
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> {
|
fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> {
|
||||||
|
@ -2247,7 +2253,8 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
};
|
};
|
||||||
|
|
||||||
let expr_id = expression.scoped_ast_id(self.db, self.scope);
|
let expr_id = expression.scoped_ast_id(self.db, self.scope);
|
||||||
self.types.expressions.insert(expr_id, ty);
|
let previous = self.types.expressions.insert(expr_id, ty);
|
||||||
|
assert!(previous.is_none());
|
||||||
|
|
||||||
ty
|
ty
|
||||||
}
|
}
|
||||||
|
@ -2808,6 +2815,25 @@ mod tests {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn basic_call_expression() -> anyhow::Result<()> {
|
||||||
|
let mut db = setup_db();
|
||||||
|
|
||||||
|
db.write_dedented(
|
||||||
|
"src/a.py",
|
||||||
|
"
|
||||||
|
def get_int() -> int:
|
||||||
|
return 42
|
||||||
|
|
||||||
|
x = get_int()
|
||||||
|
",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
assert_public_ty(&db, "src/a.py", "x", "int");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_union() -> anyhow::Result<()> {
|
fn resolve_union() -> anyhow::Result<()> {
|
||||||
let mut db = setup_db();
|
let mut db = setup_db();
|
||||||
|
|
|
@ -24,6 +24,7 @@ const TOMLLIB_312_URL: &str = "https://raw.githubusercontent.com/python/cpython/
|
||||||
// The "unresolved import" is because we don't understand `*` imports yet.
|
// The "unresolved import" is because we don't understand `*` imports yet.
|
||||||
static EXPECTED_DIAGNOSTICS: &[&str] = &[
|
static EXPECTED_DIAGNOSTICS: &[&str] = &[
|
||||||
"/src/tomllib/_parser.py:7:29: Module 'collections.abc' has no member 'Iterable'",
|
"/src/tomllib/_parser.py:7:29: Module 'collections.abc' has no member 'Iterable'",
|
||||||
|
"/src/tomllib/_parser.py:686:23: Object of type 'Unbound' is not callable",
|
||||||
"Line 69 is too long (89 characters)",
|
"Line 69 is too long (89 characters)",
|
||||||
"Use double quotes for strings",
|
"Use double quotes for strings",
|
||||||
"Use double quotes for strings",
|
"Use double quotes for strings",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue