[red-knot] function parameter types (#14802)

## Summary

Inferred and declared types for function parameters, in the function
body scope.

Fixes #13693.

## Test Plan

Added mdtests.

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Carl Meyer 2024-12-06 12:55:56 -08:00 committed by GitHub
parent 2119dcab6f
commit 3017b3b687
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 340 additions and 97 deletions

View file

@ -8,8 +8,8 @@ from typing_extensions import TypeVarTuple
Ts = TypeVarTuple("Ts")
def append_int(*args: *Ts) -> tuple[*Ts, int]:
# TODO: should show some representation of the variadic generic type
reveal_type(args) # revealed: @Todo(function parameter type)
# TODO: tuple[*Ts]
reveal_type(args) # revealed: tuple
return (*args, 1)

View file

@ -0,0 +1,75 @@
# Function parameter types
Within a function scope, the declared type of each parameter is its annotated type (or Unknown if
not annotated). The initial inferred type is the union of the declared type with the type of the
default value expression (if any). If both are fully static types, this union should simplify to the
annotated type (since the default value type must be assignable to the annotated type, and for fully
static types this means subtype-of, which simplifies in unions). But if the annotated type is
Unknown or another non-fully-static type, the default value type may still be relevant as lower
bound.
The variadic parameter is a variadic tuple of its annotated type; the variadic-keywords parameter is
a dictionary from strings to its annotated type.
## Parameter kinds
```py
from typing import Literal
def f(a, b: int, c=1, d: int = 2, /, e=3, f: Literal[4] = 4, *args: object, g=5, h: Literal[6] = 6, **kwargs: str):
reveal_type(a) # revealed: Unknown
reveal_type(b) # revealed: int
reveal_type(c) # revealed: Unknown | Literal[1]
reveal_type(d) # revealed: int
reveal_type(e) # revealed: Unknown | Literal[3]
reveal_type(f) # revealed: Literal[4]
reveal_type(g) # revealed: Unknown | Literal[5]
reveal_type(h) # revealed: Literal[6]
# TODO: should be `tuple[object, ...]` (needs generics)
reveal_type(args) # revealed: tuple
# TODO: should be `dict[str, str]` (needs generics)
reveal_type(kwargs) # revealed: dict
```
## Unannotated variadic parameters
...are inferred as tuple of Unknown or dict from string to Unknown.
```py
def g(*args, **kwargs):
# TODO: should be `tuple[Unknown, ...]` (needs generics)
reveal_type(args) # revealed: tuple
# TODO: should be `dict[str, Unknown]` (needs generics)
reveal_type(kwargs) # revealed: dict
```
## Annotation is present but not a fully static type
The default value type should be a lower bound on the inferred type.
```py
from typing import Any
def f(x: Any = 1):
reveal_type(x) # revealed: Any | Literal[1]
```
## Default value type must be assignable to annotated type
The default value type must be assignable to the annotated type. If not, we emit a diagnostic, and
fall back to inferring the annotated type, ignoring the default value type.
```py
# error: [invalid-parameter-default]
def f(x: int = "foo"):
reveal_type(x) # revealed: int
# The check is assignable-to, not subtype-of, so this is fine:
from typing import Any
def g(x: Any = "foo"):
reveal_type(x) # revealed: Any | Literal["foo"]
```

View file

@ -606,24 +606,11 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let function_table = index.symbol_table(function_scope_id);
assert_eq!(
names(&function_table),
vec!["a", "b", "c", "args", "d", "kwargs"],
vec!["a", "b", "c", "d", "args", "kwargs"],
);
let use_def = index.use_def_map(function_scope_id);
for name in ["a", "b", "c", "d"] {
let binding = use_def
.first_public_binding(
function_table
.symbol_id_by_name(name)
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
binding.kind(&db),
DefinitionKind::ParameterWithDefault(_)
));
}
for name in ["args", "kwargs"] {
let binding = use_def
.first_public_binding(
function_table
@ -633,6 +620,28 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Parameter(_)));
}
let args_binding = use_def
.first_public_binding(
function_table
.symbol_id_by_name("args")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
args_binding.kind(&db),
DefinitionKind::VariadicPositionalParameter(_)
));
let kwargs_binding = use_def
.first_public_binding(
function_table
.symbol_id_by_name("kwargs")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
kwargs_binding.kind(&db),
DefinitionKind::VariadicKeywordParameter(_)
));
}
#[test]
@ -654,25 +663,38 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let lambda_table = index.symbol_table(lambda_scope_id);
assert_eq!(
names(&lambda_table),
vec!["a", "b", "c", "args", "d", "kwargs"],
vec!["a", "b", "c", "d", "args", "kwargs"],
);
let use_def = index.use_def_map(lambda_scope_id);
for name in ["a", "b", "c", "d"] {
let binding = use_def
.first_public_binding(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
.unwrap();
assert!(matches!(
binding.kind(&db),
DefinitionKind::ParameterWithDefault(_)
));
}
for name in ["args", "kwargs"] {
let binding = use_def
.first_public_binding(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Parameter(_)));
}
let args_binding = use_def
.first_public_binding(
lambda_table
.symbol_id_by_name("args")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
args_binding.kind(&db),
DefinitionKind::VariadicPositionalParameter(_)
));
let kwargs_binding = use_def
.first_public_binding(
lambda_table
.symbol_id_by_name("kwargs")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
kwargs_binding.kind(&db),
DefinitionKind::VariadicKeywordParameter(_)
));
}
/// Test case to validate that the comprehension scope is correctly identified and that the target

View file

@ -9,7 +9,7 @@ use ruff_index::IndexVec;
use ruff_python_ast as ast;
use ruff_python_ast::name::Name;
use ruff_python_ast::visitor::{walk_expr, walk_pattern, walk_stmt, Visitor};
use ruff_python_ast::{AnyParameterRef, BoolOp, Expr};
use ruff_python_ast::{BoolOp, Expr};
use crate::ast_node_ref::AstNodeRef;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
@ -479,21 +479,35 @@ impl<'db> SemanticIndexBuilder<'db> {
self.pop_scope();
}
fn declare_parameter(&mut self, parameter: AnyParameterRef<'db>) {
let symbol = self.add_symbol(parameter.name().id().clone());
fn declare_parameters(&mut self, parameters: &'db ast::Parameters) {
for parameter in parameters.iter_non_variadic_params() {
self.declare_parameter(parameter);
}
if let Some(vararg) = parameters.vararg.as_ref() {
let symbol = self.add_symbol(vararg.name.id().clone());
self.add_definition(
symbol,
DefinitionNodeRef::VariadicPositionalParameter(vararg),
);
}
if let Some(kwarg) = parameters.kwarg.as_ref() {
let symbol = self.add_symbol(kwarg.name.id().clone());
self.add_definition(symbol, DefinitionNodeRef::VariadicKeywordParameter(kwarg));
}
}
fn declare_parameter(&mut self, parameter: &'db ast::ParameterWithDefault) {
let symbol = self.add_symbol(parameter.parameter.name.id().clone());
let definition = self.add_definition(symbol, parameter);
if let AnyParameterRef::NonVariadic(with_default) = parameter {
// Insert a mapping from the parameter to the same definition.
// This ensures that calling `HasTy::ty` on the inner parameter returns
// a valid type (and doesn't panic)
let existing_definition = self.definitions_by_node.insert(
DefinitionNodeRef::from(AnyParameterRef::Variadic(&with_default.parameter)).key(),
definition,
);
debug_assert_eq!(existing_definition, None);
}
// Insert a mapping from the inner Parameter node to the same definition.
// This ensures that calling `HasTy::ty` on the inner parameter returns
// a valid type (and doesn't panic)
let existing_definition = self
.definitions_by_node
.insert((&parameter.parameter).into(), definition);
debug_assert_eq!(existing_definition, None);
}
pub(super) fn build(mut self) -> SemanticIndex<'db> {
@ -556,34 +570,40 @@ where
fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) {
match stmt {
ast::Stmt::FunctionDef(function_def) => {
for decorator in &function_def.decorator_list {
let ast::StmtFunctionDef {
decorator_list,
parameters,
type_params,
name,
returns,
body,
is_async: _,
range: _,
} = function_def;
for decorator in decorator_list {
self.visit_decorator(decorator);
}
self.with_type_params(
NodeWithScopeRef::FunctionTypeParameters(function_def),
function_def.type_params.as_deref(),
type_params.as_deref(),
|builder| {
builder.visit_parameters(&function_def.parameters);
if let Some(expr) = &function_def.returns {
builder.visit_annotation(expr);
builder.visit_parameters(parameters);
if let Some(returns) = returns {
builder.visit_annotation(returns);
}
builder.push_scope(NodeWithScopeRef::Function(function_def));
// Add symbols and definitions for the parameters to the function scope.
for parameter in &*function_def.parameters {
builder.declare_parameter(parameter);
}
builder.declare_parameters(parameters);
builder.visit_body(&function_def.body);
builder.visit_body(body);
builder.pop_scope()
},
);
// The default value of the parameters needs to be evaluated in the
// enclosing scope.
for default in function_def
.parameters
for default in parameters
.iter_non_variadic_params()
.filter_map(|param| param.default.as_deref())
{
@ -592,7 +612,7 @@ where
// The symbol for the function name itself has to be evaluated
// at the end to match the runtime evaluation of parameter defaults
// and return-type annotations.
let symbol = self.add_symbol(function_def.name.id.clone());
let symbol = self.add_symbol(name.id.clone());
self.add_definition(symbol, function_def);
}
ast::Stmt::ClassDef(class) => {
@ -1179,10 +1199,8 @@ where
self.push_scope(NodeWithScopeRef::Lambda(lambda));
// Add symbols and definitions for the parameters to the lambda scope.
if let Some(parameters) = &lambda.parameters {
for parameter in parameters {
self.declare_parameter(parameter);
}
if let Some(parameters) = lambda.parameters.as_ref() {
self.declare_parameters(parameters);
}
self.visit_expr(lambda.body.as_ref());

View file

@ -89,7 +89,9 @@ pub(crate) enum DefinitionNodeRef<'a> {
AnnotatedAssignment(&'a ast::StmtAnnAssign),
AugmentedAssignment(&'a ast::StmtAugAssign),
Comprehension(ComprehensionDefinitionNodeRef<'a>),
Parameter(ast::AnyParameterRef<'a>),
VariadicPositionalParameter(&'a ast::Parameter),
VariadicKeywordParameter(&'a ast::Parameter),
Parameter(&'a ast::ParameterWithDefault),
WithItem(WithItemDefinitionNodeRef<'a>),
MatchPattern(MatchPatternDefinitionNodeRef<'a>),
ExceptHandler(ExceptHandlerDefinitionNodeRef<'a>),
@ -188,8 +190,8 @@ impl<'a> From<ComprehensionDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
}
}
impl<'a> From<ast::AnyParameterRef<'a>> for DefinitionNodeRef<'a> {
fn from(node: ast::AnyParameterRef<'a>) -> Self {
impl<'a> From<&'a ast::ParameterWithDefault> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::ParameterWithDefault) -> Self {
Self::Parameter(node)
}
}
@ -315,14 +317,15 @@ impl<'db> DefinitionNodeRef<'db> {
first,
is_async,
}),
DefinitionNodeRef::Parameter(parameter) => match parameter {
ast::AnyParameterRef::Variadic(parameter) => {
DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter))
}
ast::AnyParameterRef::NonVariadic(parameter) => {
DefinitionKind::ParameterWithDefault(AstNodeRef::new(parsed, parameter))
}
},
DefinitionNodeRef::VariadicPositionalParameter(parameter) => {
DefinitionKind::VariadicPositionalParameter(AstNodeRef::new(parsed, parameter))
}
DefinitionNodeRef::VariadicKeywordParameter(parameter) => {
DefinitionKind::VariadicKeywordParameter(AstNodeRef::new(parsed, parameter))
}
DefinitionNodeRef::Parameter(parameter) => {
DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter))
}
DefinitionNodeRef::WithItem(WithItemDefinitionNodeRef {
node,
target,
@ -384,10 +387,9 @@ impl<'db> DefinitionNodeRef<'db> {
is_async: _,
}) => target.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
Self::Parameter(node) => match node {
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
},
Self::VariadicPositionalParameter(node) => node.into(),
Self::VariadicKeywordParameter(node) => node.into(),
Self::Parameter(node) => node.into(),
Self::WithItem(WithItemDefinitionNodeRef {
node: _,
target,
@ -452,8 +454,9 @@ pub enum DefinitionKind<'db> {
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
For(ForStmtDefinitionKind),
Comprehension(ComprehensionDefinitionKind),
Parameter(AstNodeRef<ast::Parameter>),
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
VariadicPositionalParameter(AstNodeRef<ast::Parameter>),
VariadicKeywordParameter(AstNodeRef<ast::Parameter>),
Parameter(AstNodeRef<ast::ParameterWithDefault>),
WithItem(WithItemDefinitionKind),
MatchPattern(MatchPatternDefinitionKind),
ExceptHandler(ExceptHandlerDefinitionKind),
@ -475,7 +478,8 @@ impl DefinitionKind<'_> {
| DefinitionKind::ParamSpec(_)
| DefinitionKind::TypeVarTuple(_) => DefinitionCategory::DeclarationAndBinding,
// a parameter always binds a value, but is only a declaration if annotated
DefinitionKind::Parameter(parameter) => {
DefinitionKind::VariadicPositionalParameter(parameter)
| DefinitionKind::VariadicKeywordParameter(parameter) => {
if parameter.annotation.is_some() {
DefinitionCategory::DeclarationAndBinding
} else {
@ -483,7 +487,7 @@ impl DefinitionKind<'_> {
}
}
// presence of a default is irrelevant, same logic as for a no-default parameter
DefinitionKind::ParameterWithDefault(parameter_with_default) => {
DefinitionKind::Parameter(parameter_with_default) => {
if parameter_with_default.parameter.annotation.is_some() {
DefinitionCategory::DeclarationAndBinding
} else {
@ -743,6 +747,15 @@ impl From<&ast::ParameterWithDefault> for DefinitionNodeKey {
}
}
impl From<ast::AnyParameterRef<'_>> for DefinitionNodeKey {
fn from(value: ast::AnyParameterRef) -> Self {
Self(match value {
ast::AnyParameterRef::Variadic(node) => NodeKey::from_node(node),
ast::AnyParameterRef::NonVariadic(node) => NodeKey::from_node(node),
})
}
}
impl From<&ast::Identifier> for DefinitionNodeKey {
fn from(identifier: &ast::Identifier) -> Self {
Self(NodeKey::from_node(identifier))

View file

@ -225,6 +225,18 @@ fn definition_expression_ty<'db>(
}
}
/// Get the type of an expression from an arbitrary scope.
///
/// Can cause query cycles if used carelessly; caller must be sure that type inference isn't
/// currently in progress for the expression's scope.
fn expression_ty<'db>(db: &'db dyn Db, file: File, expression: &ast::Expr) -> Type<'db> {
let index = semantic_index(db, file);
let file_scope = index.expression_scope_id(expression);
let scope = file_scope.to_scope_id(db, file);
let expr_id = expression.scoped_expression_id(db, scope);
infer_scope_types(db, scope).expression_ty(expr_id)
}
/// Infer the combined type of an iterator of bindings.
///
/// Will return a union if there is more than one binding.

View file

@ -63,6 +63,7 @@ use crate::unpack::Unpack;
use crate::util::subscript::{PyIndex, PySlice};
use crate::Db;
use super::expression_ty;
use super::string_annotation::parse_string_annotation;
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
@ -654,11 +655,14 @@ impl<'db> TypeInferenceBuilder<'db> {
definition,
);
}
DefinitionKind::Parameter(parameter) => {
self.infer_parameter_definition(parameter, definition);
DefinitionKind::VariadicPositionalParameter(parameter) => {
self.infer_variadic_positional_parameter_definition(parameter, definition);
}
DefinitionKind::ParameterWithDefault(parameter_with_default) => {
self.infer_parameter_with_default_definition(parameter_with_default, definition);
DefinitionKind::VariadicKeywordParameter(parameter) => {
self.infer_variadic_keyword_parameter_definition(parameter, definition);
}
DefinitionKind::Parameter(parameter_with_default) => {
self.infer_parameter_definition(parameter_with_default, definition);
}
DefinitionKind::WithItem(with_item) => {
self.infer_with_item_definition(
@ -871,6 +875,12 @@ impl<'db> TypeInferenceBuilder<'db> {
}
fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) {
// Parameters are odd: they are Definitions in the function body scope, but have no
// constituent nodes that are part of the function body. In order to get diagnostics
// merged/emitted for them, we need to explicitly infer their definitions here.
for parameter in &function.parameters {
self.infer_definition(parameter);
}
self.infer_body(&function.body);
}
@ -1033,33 +1043,126 @@ impl<'db> TypeInferenceBuilder<'db> {
);
}
fn infer_parameter_with_default_definition(
/// Set initial declared type (if annotated) and inferred type for a function-parameter symbol,
/// in the function body scope.
///
/// The declared type is the annotated type, if any, or `Unknown`.
///
/// The inferred type is the annotated type, unioned with the type of the default value, if
/// any. If both types are fully static, this union is a no-op (it should simplify to just the
/// annotated type.) But in a case like `f(x=None)` with no annotated type, we want to infer
/// the type `Unknown | None` for `x`, not just `Unknown`, so that we can error on usage of `x`
/// that would not be valid for `None`.
///
/// If the default-value type is not assignable to the declared (annotated) type, we ignore the
/// default-value type and just infer the annotated type; this is the same way we handle
/// assignments, and allows an explicit annotation to override a bad inference.
///
/// Parameter definitions are odd in that they define a symbol in the function-body scope, so
/// the Definition belongs to the function body scope, but the expressions (annotation and
/// default value) both belong to outer scopes. (The default value always belongs to the outer
/// scope in which the function is defined, the annotation belongs either to the outer scope,
/// or maybe to an intervening type-params scope, if it's a generic function.) So we don't use
/// `self.infer_expression` or store any expression types here, we just use `expression_ty` to
/// get the types of the expressions from their respective scopes.
///
/// It is safe (non-cycle-causing) to use `expression_ty` here, because an outer scope can't
/// depend on a definition from an inner scope, so we shouldn't be in-process of inferring the
/// outer scope here.
fn infer_parameter_definition(
&mut self,
parameter_with_default: &ast::ParameterWithDefault,
definition: Definition<'db>,
) {
// TODO(dhruvmanila): Infer types from annotation or default expression
// TODO check that default is assignable to parameter type
self.infer_parameter_definition(&parameter_with_default.parameter, definition);
let ast::ParameterWithDefault {
parameter,
default,
range: _,
} = parameter_with_default;
let default_ty = default
.as_ref()
.map(|default| expression_ty(self.db, self.file, default));
if let Some(annotation) = parameter.annotation.as_ref() {
let declared_ty = expression_ty(self.db, self.file, annotation);
let inferred_ty = if let Some(default_ty) = default_ty {
if default_ty.is_assignable_to(self.db, declared_ty) {
UnionType::from_elements(self.db, [declared_ty, default_ty])
} else {
self.diagnostics.add(
parameter_with_default.into(),
"invalid-parameter-default",
format_args!(
"Default value of type `{}` is not assignable to annotated parameter type `{}`",
default_ty.display(self.db), declared_ty.display(self.db))
);
declared_ty
}
} else {
declared_ty
};
self.add_declaration_with_binding(
parameter.into(),
definition,
declared_ty,
inferred_ty,
);
} else {
let ty = if let Some(default_ty) = default_ty {
UnionType::from_elements(self.db, [Type::Unknown, default_ty])
} else {
Type::Unknown
};
self.add_binding(parameter.into(), definition, ty);
}
}
fn infer_parameter_definition(
/// Set initial declared/inferred types for a `*args` variadic positional parameter.
///
/// The annotated type is implicitly wrapped in a homogeneous tuple.
///
/// See `infer_parameter_definition` doc comment for some relevant observations about scopes.
fn infer_variadic_positional_parameter_definition(
&mut self,
parameter: &ast::Parameter,
definition: Definition<'db>,
) {
// TODO(dhruvmanila): Annotation expression is resolved at the enclosing scope, infer the
// parameter type from there
let annotated_ty = todo_type!("function parameter type");
if parameter.annotation.is_some() {
self.add_declaration_with_binding(
if let Some(annotation) = parameter.annotation.as_ref() {
let _annotated_ty = expression_ty(self.db, self.file, annotation);
// TODO `tuple[annotated_ty, ...]`
let ty = KnownClass::Tuple.to_instance(self.db);
self.add_declaration_with_binding(parameter.into(), definition, ty, ty);
} else {
self.add_binding(
parameter.into(),
definition,
annotated_ty,
annotated_ty,
// TODO `tuple[Unknown, ...]`
KnownClass::Tuple.to_instance(self.db),
);
}
}
/// Set initial declared/inferred types for a `*args` variadic positional parameter.
///
/// The annotated type is implicitly wrapped in a string-keyed dictionary.
///
/// See `infer_parameter_definition` doc comment for some relevant observations about scopes.
fn infer_variadic_keyword_parameter_definition(
&mut self,
parameter: &ast::Parameter,
definition: Definition<'db>,
) {
if let Some(annotation) = parameter.annotation.as_ref() {
let _annotated_ty = expression_ty(self.db, self.file, annotation);
// TODO `dict[str, annotated_ty]`
let ty = KnownClass::Dict.to_instance(self.db);
self.add_declaration_with_binding(parameter.into(), definition, ty, ty);
} else {
self.add_binding(parameter.into(), definition, annotated_ty);
self.add_binding(
parameter.into(),
definition,
// TODO `dict[str, Unknown]`
KnownClass::Dict.to_instance(self.db),
);
}
}
@ -1435,10 +1538,10 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::Tuple(tuple) => UnionType::from_elements(
self.db,
tuple.elements(self.db).iter().map(|ty| {
ty.into_class_literal()
.map_or(todo_type!(), |ClassLiteralType { class }| {
Type::instance(class)
})
ty.into_class_literal().map_or(
todo_type!("exception type"),
|ClassLiteralType { class }| Type::instance(class),
)
}),
),
_ => todo_type!("exception type"),
@ -2719,7 +2822,7 @@ impl<'db> TypeInferenceBuilder<'db> {
.unwrap_with_diagnostic(value.as_ref().into(), &mut self.diagnostics);
// TODO
todo_type!()
todo_type!("starred expression")
}
fn infer_yield_expression(&mut self, yield_expression: &ast::ExprYield) -> Type<'db> {