[red-knot] function parameter types (#14802)

## Summary

Inferred and declared types for function parameters, in the function
body scope.

Fixes #13693.

## Test Plan

Added mdtests.

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Carl Meyer 2024-12-06 12:55:56 -08:00 committed by GitHub
parent 2119dcab6f
commit 3017b3b687
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 340 additions and 97 deletions

View file

@ -606,24 +606,11 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let function_table = index.symbol_table(function_scope_id);
assert_eq!(
names(&function_table),
vec!["a", "b", "c", "args", "d", "kwargs"],
vec!["a", "b", "c", "d", "args", "kwargs"],
);
let use_def = index.use_def_map(function_scope_id);
for name in ["a", "b", "c", "d"] {
let binding = use_def
.first_public_binding(
function_table
.symbol_id_by_name(name)
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
binding.kind(&db),
DefinitionKind::ParameterWithDefault(_)
));
}
for name in ["args", "kwargs"] {
let binding = use_def
.first_public_binding(
function_table
@ -633,6 +620,28 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Parameter(_)));
}
let args_binding = use_def
.first_public_binding(
function_table
.symbol_id_by_name("args")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
args_binding.kind(&db),
DefinitionKind::VariadicPositionalParameter(_)
));
let kwargs_binding = use_def
.first_public_binding(
function_table
.symbol_id_by_name("kwargs")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
kwargs_binding.kind(&db),
DefinitionKind::VariadicKeywordParameter(_)
));
}
#[test]
@ -654,25 +663,38 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let lambda_table = index.symbol_table(lambda_scope_id);
assert_eq!(
names(&lambda_table),
vec!["a", "b", "c", "args", "d", "kwargs"],
vec!["a", "b", "c", "d", "args", "kwargs"],
);
let use_def = index.use_def_map(lambda_scope_id);
for name in ["a", "b", "c", "d"] {
let binding = use_def
.first_public_binding(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
.unwrap();
assert!(matches!(
binding.kind(&db),
DefinitionKind::ParameterWithDefault(_)
));
}
for name in ["args", "kwargs"] {
let binding = use_def
.first_public_binding(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Parameter(_)));
}
let args_binding = use_def
.first_public_binding(
lambda_table
.symbol_id_by_name("args")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
args_binding.kind(&db),
DefinitionKind::VariadicPositionalParameter(_)
));
let kwargs_binding = use_def
.first_public_binding(
lambda_table
.symbol_id_by_name("kwargs")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
kwargs_binding.kind(&db),
DefinitionKind::VariadicKeywordParameter(_)
));
}
/// Test case to validate that the comprehension scope is correctly identified and that the target

View file

@ -9,7 +9,7 @@ use ruff_index::IndexVec;
use ruff_python_ast as ast;
use ruff_python_ast::name::Name;
use ruff_python_ast::visitor::{walk_expr, walk_pattern, walk_stmt, Visitor};
use ruff_python_ast::{AnyParameterRef, BoolOp, Expr};
use ruff_python_ast::{BoolOp, Expr};
use crate::ast_node_ref::AstNodeRef;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
@ -479,21 +479,35 @@ impl<'db> SemanticIndexBuilder<'db> {
self.pop_scope();
}
fn declare_parameter(&mut self, parameter: AnyParameterRef<'db>) {
let symbol = self.add_symbol(parameter.name().id().clone());
/// Add a symbol and a definition for every parameter in `parameters`.
///
/// Non-variadic parameters are handled first via `declare_parameter`; the `*args` and
/// `**kwargs` parameters (when present) follow, each recorded with its own variadic
/// `DefinitionNodeRef` variant. Callers in this file invoke this right after pushing the
/// function or lambda scope.
fn declare_parameters(&mut self, parameters: &'db ast::Parameters) {
    parameters
        .iter_non_variadic_params()
        .for_each(|parameter| self.declare_parameter(parameter));

    // `*args`
    if let Some(variadic_positional) = &parameters.vararg {
        let symbol = self.add_symbol(variadic_positional.name.id().clone());
        let node = DefinitionNodeRef::VariadicPositionalParameter(variadic_positional);
        self.add_definition(symbol, node);
    }

    // `**kwargs`
    if let Some(variadic_keyword) = &parameters.kwarg {
        let symbol = self.add_symbol(variadic_keyword.name.id().clone());
        let node = DefinitionNodeRef::VariadicKeywordParameter(variadic_keyword);
        self.add_definition(symbol, node);
    }
}
fn declare_parameter(&mut self, parameter: &'db ast::ParameterWithDefault) {
let symbol = self.add_symbol(parameter.parameter.name.id().clone());
let definition = self.add_definition(symbol, parameter);
if let AnyParameterRef::NonVariadic(with_default) = parameter {
// Insert a mapping from the parameter to the same definition.
// This ensures that calling `HasTy::ty` on the inner parameter returns
// a valid type (and doesn't panic)
let existing_definition = self.definitions_by_node.insert(
DefinitionNodeRef::from(AnyParameterRef::Variadic(&with_default.parameter)).key(),
definition,
);
debug_assert_eq!(existing_definition, None);
}
// Insert a mapping from the inner Parameter node to the same definition.
// This ensures that calling `HasTy::ty` on the inner parameter returns
// a valid type (and doesn't panic)
let existing_definition = self
.definitions_by_node
.insert((&parameter.parameter).into(), definition);
debug_assert_eq!(existing_definition, None);
}
pub(super) fn build(mut self) -> SemanticIndex<'db> {
@ -556,34 +570,40 @@ where
fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) {
match stmt {
ast::Stmt::FunctionDef(function_def) => {
for decorator in &function_def.decorator_list {
let ast::StmtFunctionDef {
decorator_list,
parameters,
type_params,
name,
returns,
body,
is_async: _,
range: _,
} = function_def;
for decorator in decorator_list {
self.visit_decorator(decorator);
}
self.with_type_params(
NodeWithScopeRef::FunctionTypeParameters(function_def),
function_def.type_params.as_deref(),
type_params.as_deref(),
|builder| {
builder.visit_parameters(&function_def.parameters);
if let Some(expr) = &function_def.returns {
builder.visit_annotation(expr);
builder.visit_parameters(parameters);
if let Some(returns) = returns {
builder.visit_annotation(returns);
}
builder.push_scope(NodeWithScopeRef::Function(function_def));
// Add symbols and definitions for the parameters to the function scope.
for parameter in &*function_def.parameters {
builder.declare_parameter(parameter);
}
builder.declare_parameters(parameters);
builder.visit_body(&function_def.body);
builder.visit_body(body);
builder.pop_scope()
},
);
// The default value of the parameters needs to be evaluated in the
// enclosing scope.
for default in function_def
.parameters
for default in parameters
.iter_non_variadic_params()
.filter_map(|param| param.default.as_deref())
{
@ -592,7 +612,7 @@ where
// The symbol for the function name itself has to be evaluated
// at the end to match the runtime evaluation of parameter defaults
// and return-type annotations.
let symbol = self.add_symbol(function_def.name.id.clone());
let symbol = self.add_symbol(name.id.clone());
self.add_definition(symbol, function_def);
}
ast::Stmt::ClassDef(class) => {
@ -1179,10 +1199,8 @@ where
self.push_scope(NodeWithScopeRef::Lambda(lambda));
// Add symbols and definitions for the parameters to the lambda scope.
if let Some(parameters) = &lambda.parameters {
for parameter in parameters {
self.declare_parameter(parameter);
}
if let Some(parameters) = lambda.parameters.as_ref() {
self.declare_parameters(parameters);
}
self.visit_expr(lambda.body.as_ref());

View file

@ -89,7 +89,9 @@ pub(crate) enum DefinitionNodeRef<'a> {
AnnotatedAssignment(&'a ast::StmtAnnAssign),
AugmentedAssignment(&'a ast::StmtAugAssign),
Comprehension(ComprehensionDefinitionNodeRef<'a>),
Parameter(ast::AnyParameterRef<'a>),
VariadicPositionalParameter(&'a ast::Parameter),
VariadicKeywordParameter(&'a ast::Parameter),
Parameter(&'a ast::ParameterWithDefault),
WithItem(WithItemDefinitionNodeRef<'a>),
MatchPattern(MatchPatternDefinitionNodeRef<'a>),
ExceptHandler(ExceptHandlerDefinitionNodeRef<'a>),
@ -188,8 +190,8 @@ impl<'a> From<ComprehensionDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
}
}
impl<'a> From<ast::AnyParameterRef<'a>> for DefinitionNodeRef<'a> {
fn from(node: ast::AnyParameterRef<'a>) -> Self {
impl<'a> From<&'a ast::ParameterWithDefault> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::ParameterWithDefault) -> Self {
Self::Parameter(node)
}
}
@ -315,14 +317,15 @@ impl<'db> DefinitionNodeRef<'db> {
first,
is_async,
}),
DefinitionNodeRef::Parameter(parameter) => match parameter {
ast::AnyParameterRef::Variadic(parameter) => {
DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter))
}
ast::AnyParameterRef::NonVariadic(parameter) => {
DefinitionKind::ParameterWithDefault(AstNodeRef::new(parsed, parameter))
}
},
DefinitionNodeRef::VariadicPositionalParameter(parameter) => {
DefinitionKind::VariadicPositionalParameter(AstNodeRef::new(parsed, parameter))
}
DefinitionNodeRef::VariadicKeywordParameter(parameter) => {
DefinitionKind::VariadicKeywordParameter(AstNodeRef::new(parsed, parameter))
}
DefinitionNodeRef::Parameter(parameter) => {
DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter))
}
DefinitionNodeRef::WithItem(WithItemDefinitionNodeRef {
node,
target,
@ -384,10 +387,9 @@ impl<'db> DefinitionNodeRef<'db> {
is_async: _,
}) => target.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
Self::Parameter(node) => match node {
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
},
Self::VariadicPositionalParameter(node) => node.into(),
Self::VariadicKeywordParameter(node) => node.into(),
Self::Parameter(node) => node.into(),
Self::WithItem(WithItemDefinitionNodeRef {
node: _,
target,
@ -452,8 +454,9 @@ pub enum DefinitionKind<'db> {
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
For(ForStmtDefinitionKind),
Comprehension(ComprehensionDefinitionKind),
Parameter(AstNodeRef<ast::Parameter>),
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
VariadicPositionalParameter(AstNodeRef<ast::Parameter>),
VariadicKeywordParameter(AstNodeRef<ast::Parameter>),
Parameter(AstNodeRef<ast::ParameterWithDefault>),
WithItem(WithItemDefinitionKind),
MatchPattern(MatchPatternDefinitionKind),
ExceptHandler(ExceptHandlerDefinitionKind),
@ -475,7 +478,8 @@ impl DefinitionKind<'_> {
| DefinitionKind::ParamSpec(_)
| DefinitionKind::TypeVarTuple(_) => DefinitionCategory::DeclarationAndBinding,
// a parameter always binds a value, but is only a declaration if annotated
DefinitionKind::Parameter(parameter) => {
DefinitionKind::VariadicPositionalParameter(parameter)
| DefinitionKind::VariadicKeywordParameter(parameter) => {
if parameter.annotation.is_some() {
DefinitionCategory::DeclarationAndBinding
} else {
@ -483,7 +487,7 @@ impl DefinitionKind<'_> {
}
}
// presence of a default is irrelevant, same logic as for a no-default parameter
DefinitionKind::ParameterWithDefault(parameter_with_default) => {
DefinitionKind::Parameter(parameter_with_default) => {
if parameter_with_default.parameter.annotation.is_some() {
DefinitionCategory::DeclarationAndBinding
} else {
@ -743,6 +747,15 @@ impl From<&ast::ParameterWithDefault> for DefinitionNodeKey {
}
}
/// Build a definition key from either kind of parameter reference, keyed on the wrapped
/// AST node (the bare `Parameter` for variadics, the `ParameterWithDefault` otherwise).
impl From<ast::AnyParameterRef<'_>> for DefinitionNodeKey {
    fn from(value: ast::AnyParameterRef) -> Self {
        let key = match value {
            ast::AnyParameterRef::Variadic(parameter) => NodeKey::from_node(parameter),
            ast::AnyParameterRef::NonVariadic(with_default) => NodeKey::from_node(with_default),
        };
        Self(key)
    }
}
impl From<&ast::Identifier> for DefinitionNodeKey {
fn from(identifier: &ast::Identifier) -> Self {
Self(NodeKey::from_node(identifier))

View file

@ -225,6 +225,18 @@ fn definition_expression_ty<'db>(
}
}
/// Look up the inferred type of `expression`, which may belong to any scope of `file`.
///
/// The scope that owns the expression is resolved through the semantic index, and the type is
/// read from that scope's inference results.
///
/// Used carelessly this can cause query cycles: the caller must be sure that type inference is
/// not currently in progress for the expression's scope.
fn expression_ty<'db>(db: &'db dyn Db, file: File, expression: &ast::Expr) -> Type<'db> {
    let scope = semantic_index(db, file)
        .expression_scope_id(expression)
        .to_scope_id(db, file);
    let expression_id = expression.scoped_expression_id(db, scope);
    infer_scope_types(db, scope).expression_ty(expression_id)
}
/// Infer the combined type of an iterator of bindings.
///
/// Will return a union if there is more than one binding.

View file

@ -63,6 +63,7 @@ use crate::unpack::Unpack;
use crate::util::subscript::{PyIndex, PySlice};
use crate::Db;
use super::expression_ty;
use super::string_annotation::parse_string_annotation;
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
@ -654,11 +655,14 @@ impl<'db> TypeInferenceBuilder<'db> {
definition,
);
}
DefinitionKind::Parameter(parameter) => {
self.infer_parameter_definition(parameter, definition);
DefinitionKind::VariadicPositionalParameter(parameter) => {
self.infer_variadic_positional_parameter_definition(parameter, definition);
}
DefinitionKind::ParameterWithDefault(parameter_with_default) => {
self.infer_parameter_with_default_definition(parameter_with_default, definition);
DefinitionKind::VariadicKeywordParameter(parameter) => {
self.infer_variadic_keyword_parameter_definition(parameter, definition);
}
DefinitionKind::Parameter(parameter_with_default) => {
self.infer_parameter_definition(parameter_with_default, definition);
}
DefinitionKind::WithItem(with_item) => {
self.infer_with_item_definition(
@ -871,6 +875,12 @@ impl<'db> TypeInferenceBuilder<'db> {
}
/// Infer the parameter definitions of `function`, then its body statements.
fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) {
    let ast::StmtFunctionDef {
        parameters, body, ..
    } = function;

    // Parameter definitions live in the function body scope, yet none of their nodes are
    // part of the body itself. They are therefore inferred explicitly here, so that any
    // diagnostics for them get merged/emitted.
    for parameter in parameters {
        self.infer_definition(parameter);
    }

    self.infer_body(body);
}
@ -1033,33 +1043,126 @@ impl<'db> TypeInferenceBuilder<'db> {
);
}
fn infer_parameter_with_default_definition(
/// Set initial declared type (if annotated) and inferred type for a function-parameter symbol,
/// in the function body scope.
///
/// The declared type is the annotated type, if any, or `Unknown`.
///
/// The inferred type is the annotated type, unioned with the type of the default value, if
/// any. If both types are fully static, this union is a no-op (it should simplify to just the
/// annotated type.) But in a case like `f(x=None)` with no annotated type, we want to infer
/// the type `Unknown | None` for `x`, not just `Unknown`, so that we can error on usage of `x`
/// that would not be valid for `None`.
///
/// If the default-value type is not assignable to the declared (annotated) type, we ignore the
/// default-value type and just infer the annotated type; this is the same way we handle
/// assignments, and allows an explicit annotation to override a bad inference.
///
/// Parameter definitions are odd in that they define a symbol in the function-body scope, so
/// the Definition belongs to the function body scope, but the expressions (annotation and
/// default value) both belong to outer scopes. (The default value always belongs to the outer
/// scope in which the function is defined, the annotation belongs either to the outer scope,
/// or maybe to an intervening type-params scope, if it's a generic function.) So we don't use
/// `self.infer_expression` or store any expression types here, we just use `expression_ty` to
/// get the types of the expressions from their respective scopes.
///
/// It is safe (non-cycle-causing) to use `expression_ty` here, because an outer scope can't
/// depend on a definition from an inner scope, so we shouldn't be in-process of inferring the
/// outer scope here.
fn infer_parameter_definition(
&mut self,
parameter_with_default: &ast::ParameterWithDefault,
definition: Definition<'db>,
) {
// TODO(dhruvmanila): Infer types from annotation or default expression
// TODO check that default is assignable to parameter type
self.infer_parameter_definition(&parameter_with_default.parameter, definition);
let ast::ParameterWithDefault {
parameter,
default,
range: _,
} = parameter_with_default;
let default_ty = default
.as_ref()
.map(|default| expression_ty(self.db, self.file, default));
if let Some(annotation) = parameter.annotation.as_ref() {
let declared_ty = expression_ty(self.db, self.file, annotation);
let inferred_ty = if let Some(default_ty) = default_ty {
if default_ty.is_assignable_to(self.db, declared_ty) {
UnionType::from_elements(self.db, [declared_ty, default_ty])
} else {
self.diagnostics.add(
parameter_with_default.into(),
"invalid-parameter-default",
format_args!(
"Default value of type `{}` is not assignable to annotated parameter type `{}`",
default_ty.display(self.db), declared_ty.display(self.db))
);
declared_ty
}
} else {
declared_ty
};
self.add_declaration_with_binding(
parameter.into(),
definition,
declared_ty,
inferred_ty,
);
} else {
let ty = if let Some(default_ty) = default_ty {
UnionType::from_elements(self.db, [Type::Unknown, default_ty])
} else {
Type::Unknown
};
self.add_binding(parameter.into(), definition, ty);
}
}
fn infer_parameter_definition(
/// Set initial declared/inferred types for a `*args` variadic positional parameter.
///
/// The annotated type is implicitly wrapped in a homogeneous tuple.
///
/// See `infer_parameter_definition` doc comment for some relevant observations about scopes.
fn infer_variadic_positional_parameter_definition(
&mut self,
parameter: &ast::Parameter,
definition: Definition<'db>,
) {
// TODO(dhruvmanila): Annotation expression is resolved at the enclosing scope, infer the
// parameter type from there
let annotated_ty = todo_type!("function parameter type");
if parameter.annotation.is_some() {
self.add_declaration_with_binding(
if let Some(annotation) = parameter.annotation.as_ref() {
let _annotated_ty = expression_ty(self.db, self.file, annotation);
// TODO `tuple[annotated_ty, ...]`
let ty = KnownClass::Tuple.to_instance(self.db);
self.add_declaration_with_binding(parameter.into(), definition, ty, ty);
} else {
self.add_binding(
parameter.into(),
definition,
annotated_ty,
annotated_ty,
// TODO `tuple[Unknown, ...]`
KnownClass::Tuple.to_instance(self.db),
);
}
}
/// Set initial declared/inferred types for a `**kwargs` variadic keyword parameter.
///
/// The annotated type is implicitly wrapped in a string-keyed dictionary.
///
/// See `infer_parameter_definition` doc comment for some relevant observations about scopes.
fn infer_variadic_keyword_parameter_definition(
&mut self,
parameter: &ast::Parameter,
definition: Definition<'db>,
) {
if let Some(annotation) = parameter.annotation.as_ref() {
// The annotation's type is looked up but not used yet; see the TODO below.
let _annotated_ty = expression_ty(self.db, self.file, annotation);
// TODO `dict[str, annotated_ty]`
let ty = KnownClass::Dict.to_instance(self.db);
self.add_declaration_with_binding(parameter.into(), definition, ty, ty);
} else {
self.add_binding(parameter.into(), definition, annotated_ty);
self.add_binding(
parameter.into(),
definition,
// TODO `dict[str, Unknown]`
KnownClass::Dict.to_instance(self.db),
);
}
}
@ -1435,10 +1538,10 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::Tuple(tuple) => UnionType::from_elements(
self.db,
tuple.elements(self.db).iter().map(|ty| {
ty.into_class_literal()
.map_or(todo_type!(), |ClassLiteralType { class }| {
Type::instance(class)
})
ty.into_class_literal().map_or(
todo_type!("exception type"),
|ClassLiteralType { class }| Type::instance(class),
)
}),
),
_ => todo_type!("exception type"),
@ -2719,7 +2822,7 @@ impl<'db> TypeInferenceBuilder<'db> {
.unwrap_with_diagnostic(value.as_ref().into(), &mut self.diagnostics);
// TODO
todo_type!()
todo_type!("starred expression")
}
fn infer_yield_expression(&mut self, yield_expression: &ast::ExprYield) -> Type<'db> {