[red-knot] Refactor KnownFunction::takes_expression_arguments() (#15406)

This commit is contained in:
Alex Waygood 2025-01-10 19:09:03 +00:00 committed by GitHub
parent 12f86f39a4
commit c82932e580
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 89 additions and 38 deletions

View file

@ -3428,37 +3428,90 @@ impl KnownFunction {
} }
} }
/// Returns a `u32` bitmask specifying whether or not /// Return the [`ParameterExpectations`] for this function.
/// arguments given to a particular function const fn parameter_expectations(self) -> ParameterExpectations {
/// should be interpreted as type expressions or value expressions.
///
/// The argument is treated as a type expression
/// when the corresponding bit is `1`.
/// The least-significant (right-most) bit corresponds to
/// the argument at the index 0 and so on.
///
/// For example, `assert_type()` has the bitmask value of `0b10`.
/// This means the second argument is a type expression and the first a value expression.
const fn takes_type_expression_arguments(self) -> u32 {
const ALL_VALUES: u32 = 0b0;
const SINGLE_TYPE: u32 = 0b1;
const TYPE_TYPE: u32 = 0b11;
const VALUE_TYPE: u32 = 0b10;
match self { match self {
KnownFunction::IsEquivalentTo => TYPE_TYPE, Self::IsFullyStatic | Self::IsSingleton | Self::IsSingleValued => {
KnownFunction::IsSubtypeOf => TYPE_TYPE, ParameterExpectations::SingleTypeExpression
KnownFunction::IsAssignableTo => TYPE_TYPE, }
KnownFunction::IsDisjointFrom => TYPE_TYPE,
KnownFunction::IsFullyStatic => SINGLE_TYPE, Self::IsEquivalentTo
KnownFunction::IsSingleton => SINGLE_TYPE, | Self::IsSubtypeOf
KnownFunction::IsSingleValued => SINGLE_TYPE, | Self::IsAssignableTo
KnownFunction::AssertType => VALUE_TYPE, | Self::IsDisjointFrom => ParameterExpectations::TwoTypeExpressions,
_ => ALL_VALUES,
Self::AssertType => ParameterExpectations::ValueExpressionAndTypeExpression,
Self::ConstraintFunction(_)
| Self::Len
| Self::Final
| Self::NoTypeCheck
| Self::RevealType
| Self::StaticAssert => ParameterExpectations::AllValueExpressions,
} }
} }
} }
/// Describes whether the parameters in a function expect value expressions or type expressions.
///
/// Whether a specific parameter in the function expects a type expression can be queried
/// using [`ParameterExpectations::expectation_at_index`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
enum ParameterExpectations {
/// All parameters in the function expect value expressions
#[default]
AllValueExpressions,
/// The first parameter in the function expects a type expression
SingleTypeExpression,
/// The first two parameters in the function expect type expressions
TwoTypeExpressions,
/// The first parameter in the function expects a value expression,
/// and the second expects a type expression
ValueExpressionAndTypeExpression,
}
impl ParameterExpectations {
/// Query whether the parameter at `parameter_index` expects a value expression or a type expression
fn expectation_at_index(self, parameter_index: usize) -> ParameterExpectation {
match self {
Self::AllValueExpressions => ParameterExpectation::ValueExpression,
Self::SingleTypeExpression => {
if parameter_index == 0 {
ParameterExpectation::TypeExpression
} else {
ParameterExpectation::ValueExpression
}
}
Self::TwoTypeExpressions => {
if parameter_index < 2 {
ParameterExpectation::TypeExpression
} else {
ParameterExpectation::ValueExpression
}
}
Self::ValueExpressionAndTypeExpression => {
if parameter_index == 1 {
ParameterExpectation::TypeExpression
} else {
ParameterExpectation::ValueExpression
}
}
}
}
}
/// Whether a single parameter in a given function expects a value expression or a [type expression]
///
/// [type expression]: https://typing.readthedocs.io/en/latest/spec/annotations.html#type-and-annotation-expressions
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
enum ParameterExpectation {
/// The parameter expects a value expression
#[default]
ValueExpression,
/// The parameter expects a type expression
TypeExpression,
}
#[salsa::interned] #[salsa::interned]
pub struct ModuleLiteralType<'db> { pub struct ModuleLiteralType<'db> {
/// The file in which this module was imported. /// The file in which this module was imported.

View file

@ -83,6 +83,7 @@ use super::slots::check_class_slots;
use super::string_annotation::{ use super::string_annotation::{
parse_string_annotation, BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION, parse_string_annotation, BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION,
}; };
use super::{ParameterExpectation, ParameterExpectations};
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the
@ -956,7 +957,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_type_parameters(type_params); self.infer_type_parameters(type_params);
if let Some(arguments) = class.arguments.as_deref() { if let Some(arguments) = class.arguments.as_deref() {
self.infer_arguments(arguments, 0b0); self.infer_arguments(arguments, ParameterExpectations::default());
} }
} }
@ -2601,18 +2602,15 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_arguments<'a>( fn infer_arguments<'a>(
&mut self, &mut self,
arguments: &'a ast::Arguments, arguments: &'a ast::Arguments,
infer_as_type_expressions: u32, parameter_expectations: ParameterExpectations,
) -> CallArguments<'a, 'db> { ) -> CallArguments<'a, 'db> {
arguments arguments
.arguments_source_order() .arguments_source_order()
.enumerate() .enumerate()
.map(|(index, arg_or_keyword)| { .map(|(index, arg_or_keyword)| {
let infer_argument_type = if index < u32::BITS as usize let infer_argument_type = match parameter_expectations.expectation_at_index(index) {
&& infer_as_type_expressions & (1 << index) != 0 ParameterExpectation::TypeExpression => Self::infer_type_expression,
{ ParameterExpectation::ValueExpression => Self::infer_expression,
Self::infer_type_expression
} else {
Self::infer_expression
}; };
match arg_or_keyword { match arg_or_keyword {
@ -3157,13 +3155,13 @@ impl<'db> TypeInferenceBuilder<'db> {
let function_type = self.infer_expression(func); let function_type = self.infer_expression(func);
let infer_arguments_as_type_expressions = function_type let parameter_expectations = function_type
.into_function_literal() .into_function_literal()
.and_then(|f| f.known(self.db())) .and_then(|f| f.known(self.db()))
.map(KnownFunction::takes_type_expression_arguments) .map(KnownFunction::parameter_expectations)
.unwrap_or(0b0); .unwrap_or_default();
let call_arguments = self.infer_arguments(arguments, infer_arguments_as_type_expressions); let call_arguments = self.infer_arguments(arguments, parameter_expectations);
function_type function_type
.call(self.db(), &call_arguments) .call(self.db(), &call_arguments)
.unwrap_with_diagnostic(&self.context, call_expression.into()) .unwrap_with_diagnostic(&self.context, call_expression.into())