[red-knot] infer basic (name-based) annotation expressions (#13130)

## Summary

- Introduce methods for inferring annotation and type expressions.
- Correctly infer explicit return types from functions where they are
simple names that can be resolved in scope.

Contributes to #12701 by way of helping unlock call expressions (this
does not remotely finish that, as it stands, but it gets us moving that
direction).

## Test Plan

Added a test for function return types which use the name form of an
annotation expression, since this is aiming toward call expressions.
When we extend this to working for other annotation and type expression
positions, we should add explicit tests for those as well.

---------

Co-authored-by: Alex Waygood <alex.waygood@gmail.com>
Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Chris Krycho 2024-08-30 09:24:36 -06:00 committed by GitHub
parent 34b4732c46
commit f8656ff35e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 238 additions and 8 deletions

View file

@ -336,6 +336,8 @@ pub struct FunctionType<'db> {
/// name of the function at definition
pub name: ast::name::Name,
definition: Definition<'db>,
/// types of all decorators on this function
decorators: Vec<Type<'db>>,
}
@ -344,6 +346,19 @@ impl<'db> FunctionType<'db> {
pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool {
self.decorators(db).contains(&decorator)
}
/// annotated return type for this function, if any
pub fn returns(&self, db: &'db dyn Db) -> Option<Type<'db>> {
let definition = self.definition(db);
let DefinitionKind::Function(function_stmt_node) = definition.node(db) else {
panic!("Function type definition must have `DefinitionKind::Function`")
};
function_stmt_node
.returns
.as_ref()
.map(|returns| definition_expression_ty(db, definition, returns.as_ref()))
}
}
#[salsa::interned]

View file

@ -77,6 +77,7 @@ fn infer_definition_types_cycle_recovery<'db>(
_cycle: &salsa::Cycle,
input: Definition<'db>,
) -> TypeInference<'db> {
tracing::trace!("infer_definition_types_cycle_recovery");
let mut inference = TypeInference::default();
inference.definitions.insert(input, Type::Unknown);
inference
@ -420,9 +421,7 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_region_deferred(&mut self, definition: Definition<'db>) {
match definition.node(self.db) {
DefinitionKind::Function(_function) => {
// TODO self.infer_function_deferred(function.node());
}
DefinitionKind::Function(function) => self.infer_function_deferred(function.node()),
DefinitionKind::Class(class) => self.infer_class_deferred(class.node()),
DefinitionKind::AnnotatedAssignment(_annotated_assignment) => {
// TODO self.infer_annotated_assignment_deferred(annotated_assignment.node());
@ -460,7 +459,12 @@ impl<'db> TypeInferenceBuilder<'db> {
let Some(type_params) = function.type_params.as_deref() else {
panic!("function type params scope without type params");
};
self.infer_optional_expression(function.returns.as_deref());
// TODO: this should also be applied to parameter annotations.
if !self.is_stub() {
self.infer_optional_expression(function.returns.as_deref());
}
self.infer_type_parameters(type_params);
self.infer_parameters(&function.parameters);
}
@ -549,14 +553,23 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_expression(default);
}
// If there are type params, parameters and returns are evaluated in that scope.
// If there are type params, parameters and returns are evaluated in that scope, that is, in
// `infer_function_type_params`, rather than here.
if type_params.is_none() {
self.infer_parameters(parameters);
self.infer_optional_expression(returns.as_deref());
// TODO: this should also be applied to parameter annotations.
if !self.is_stub() {
self.infer_optional_annotation_expression(returns.as_deref());
}
}
let function_ty =
Type::Function(FunctionType::new(self.db, name.id.clone(), decorator_tys));
let function_ty = Type::Function(FunctionType::new(
self.db,
name.id.clone(),
definition,
decorator_tys,
));
self.types.definitions.insert(definition, function_ty);
}
@ -670,6 +683,13 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
fn infer_function_deferred(&mut self, function: &ast::StmtFunctionDef) {
if self.is_stub() {
self.types.has_deferred = true;
self.infer_optional_annotation_expression(function.returns.as_deref());
}
}
fn infer_class_deferred(&mut self, class: &ast::StmtClassDef) {
if self.is_stub() {
self.types.has_deferred = true;
@ -1297,6 +1317,13 @@ impl<'db> TypeInferenceBuilder<'db> {
expression.map(|expr| self.infer_expression(expr))
}
fn infer_optional_annotation_expression(
&mut self,
expr: Option<&ast::Expr>,
) -> Option<Type<'db>> {
expr.map(|expr| self.infer_annotation_expression(expr))
}
fn infer_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
let ty = match expression {
ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _ }) => Type::None,
@ -2059,6 +2086,173 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
/// Annotation expressions.
impl<'db> TypeInferenceBuilder<'db> {
fn infer_annotation_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
// https://typing.readthedocs.io/en/latest/spec/annotations.html#grammar-token-expression-grammar-annotation_expression
match expression {
// TODO: parse the expression and check whether it is a string annotation, since they
// can be annotation expressions distinct from type expressions.
// https://typing.readthedocs.io/en/latest/spec/annotations.html#string-annotations
ast::Expr::StringLiteral(_literal) => Type::Unknown,
// Annotation expressions also get special handling for `*args` and `**kwargs`.
ast::Expr::Starred(starred) => self.infer_starred_expression(starred),
// All other annotation expressions are (possibly) valid type expressions, so handle
// them there instead.
type_expr => self.infer_type_expression(type_expr),
}
}
}
/// Type expressions
impl<'db> TypeInferenceBuilder<'db> {
fn infer_type_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
// https://typing.readthedocs.io/en/latest/spec/annotations.html#grammar-token-expression-grammar-type_expression
// TODO: this does not include any of the special forms, and is only a
// stub of the forms other than a standalone name in scope.
let ty = match expression {
ast::Expr::Name(name) => {
debug_assert!(
name.ctx.is_load(),
"name in a type expression is always 'load' but got: '{:?}'",
name.ctx
);
self.infer_name_expression(name).instance()
}
ast::Expr::NoneLiteral(_literal) => Type::None,
// TODO: parse the expression and check whether it is a string annotation.
// https://typing.readthedocs.io/en/latest/spec/annotations.html#string-annotations
ast::Expr::StringLiteral(_literal) => Type::Unknown,
// TODO: an Ellipsis literal *on its own* does not have any meaning in annotation
// expressions, but is meaningful in the context of a number of special forms.
ast::Expr::EllipsisLiteral(_literal) => Type::Unknown,
// Other literals do not have meaningful values in the annotation expression context.
// However, we will we want to handle these differently when working with special forms,
// since (e.g.) `123` is not valid in an annotation expression but `Literal[123]` is.
ast::Expr::BytesLiteral(_literal) => Type::Unknown,
ast::Expr::NumberLiteral(_literal) => Type::Unknown,
ast::Expr::BooleanLiteral(_literal) => Type::Unknown,
// Forms which are invalid in the context of annotation expressions: we infer their
// nested expressions as normal expressions, but the type of the top-level expression is
// always `Type::Unknown` in these cases.
ast::Expr::BoolOp(bool_op) => {
self.infer_boolean_expression(bool_op);
Type::Unknown
}
ast::Expr::Named(named) => {
self.infer_named_expression(named);
Type::Unknown
}
ast::Expr::BinOp(binary) => {
self.infer_binary_expression(binary);
Type::Unknown
}
ast::Expr::UnaryOp(unary) => {
self.infer_unary_expression(unary);
Type::Unknown
}
ast::Expr::Lambda(lambda_expression) => {
self.infer_lambda_expression(lambda_expression);
Type::Unknown
}
ast::Expr::If(if_expression) => {
self.infer_if_expression(if_expression);
Type::Unknown
}
ast::Expr::Dict(dict) => {
self.infer_dict_expression(dict);
Type::Unknown
}
ast::Expr::Set(set) => {
self.infer_set_expression(set);
Type::Unknown
}
ast::Expr::ListComp(listcomp) => {
self.infer_list_comprehension_expression(listcomp);
Type::Unknown
}
ast::Expr::SetComp(setcomp) => {
self.infer_set_comprehension_expression(setcomp);
Type::Unknown
}
ast::Expr::DictComp(dictcomp) => {
self.infer_dict_comprehension_expression(dictcomp);
Type::Unknown
}
ast::Expr::Generator(generator) => {
self.infer_generator_expression(generator);
Type::Unknown
}
ast::Expr::Await(await_expression) => {
self.infer_await_expression(await_expression);
Type::Unknown
}
ast::Expr::Yield(yield_expression) => {
self.infer_yield_expression(yield_expression);
Type::Unknown
}
ast::Expr::YieldFrom(yield_from) => {
self.infer_yield_from_expression(yield_from);
Type::Unknown
}
ast::Expr::Compare(compare) => {
self.infer_compare_expression(compare);
Type::Unknown
}
ast::Expr::Call(call_expr) => {
self.infer_call_expression(call_expr);
Type::Unknown
}
ast::Expr::FString(fstring) => {
self.infer_fstring_expression(fstring);
Type::Unknown
}
//
ast::Expr::Attribute(attribute) => {
self.infer_attribute_expression(attribute);
Type::Unknown
}
// TODO: this may be a place we need to revisit with special forms.
ast::Expr::Subscript(subscript) => {
self.infer_subscript_expression(subscript);
Type::Unknown
}
ast::Expr::Starred(starred) => {
self.infer_starred_expression(starred);
Type::Unknown
}
ast::Expr::List(list) => {
self.infer_list_expression(list);
Type::Unknown
}
ast::Expr::Tuple(tuple) => {
self.infer_tuple_expression(tuple);
Type::Unknown
}
ast::Expr::Slice(slice) => {
self.infer_slice_expression(slice);
Type::Unknown
}
ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"),
};
let expr_id = expression.scoped_ast_id(self.db, self.scope);
self.types.expressions.insert(expr_id, ty);
ty
}
}
fn format_import_from_module(level: u32, module: Option<&str>) -> String {
format!(
"{}{}",
@ -2593,6 +2787,27 @@ mod tests {
Ok(())
}
#[test]
fn function_return_type() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_file("src/a.py", "def example() -> int: return 42")?;
let mod_file = system_path_to_file(&db, "src/a.py").unwrap();
let ty = global_symbol_ty_by_name(&db, mod_file, "example");
let Type::Function(function) = ty else {
panic!("example is not a function");
};
let returns = function
.returns(&db)
.expect("There is a return type on the function");
assert_eq!(returns.display(&db).to_string(), "int");
Ok(())
}
#[test]
fn resolve_union() -> anyhow::Result<()> {
let mut db = setup_db();