[ty] Simplify KnownClass::check_call() and KnownFunction::check_call() (#18981)

This commit is contained in:
Alex Waygood 2025-06-27 12:23:29 +01:00 committed by GitHub
parent 3c18d85c7d
commit 57bd7d055d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 64 additions and 76 deletions

View file

@ -3201,16 +3201,16 @@ impl KnownClass {
} }
} }
/// Evaluate a call to this known class, and emit any diagnostics that are necessary /// Evaluate a call to this known class, emit any diagnostics that are necessary
/// as a result of the call. /// as a result of the call, and return the type that results from the call.
pub(super) fn check_call<'db>( pub(super) fn check_call<'db>(
self, self,
context: &InferContext<'db, '_>, context: &InferContext<'db, '_>,
index: &SemanticIndex<'db>, index: &SemanticIndex<'db>,
overload_binding: &mut Binding<'db>, overload_binding: &Binding<'db>,
call_argument_types: &CallArgumentTypes<'_, 'db>, call_argument_types: &CallArgumentTypes<'_, 'db>,
call_expression: &ast::ExprCall, call_expression: &ast::ExprCall,
) { ) -> Option<Type<'db>> {
let db = context.db(); let db = context.db();
let scope = context.scope(); let scope = context.scope();
let module = context.module(); let module = context.module();
@ -3226,10 +3226,9 @@ impl KnownClass {
let Some(enclosing_class) = let Some(enclosing_class) =
nearest_enclosing_class(db, index, scope, module) nearest_enclosing_class(db, index, scope, module)
else { else {
overload_binding.set_return_type(Type::unknown());
BoundSuperError::UnavailableImplicitArguments BoundSuperError::UnavailableImplicitArguments
.report_diagnostic(context, call_expression.into()); .report_diagnostic(context, call_expression.into());
return; return Some(Type::unknown());
}; };
// The type of the first parameter if the given scope is function-like (i.e. function or lambda). // The type of the first parameter if the given scope is function-like (i.e. function or lambda).
@ -3249,10 +3248,9 @@ impl KnownClass {
}; };
let Some(first_param) = first_param else { let Some(first_param) = first_param else {
overload_binding.set_return_type(Type::unknown());
BoundSuperError::UnavailableImplicitArguments BoundSuperError::UnavailableImplicitArguments
.report_diagnostic(context, call_expression.into()); .report_diagnostic(context, call_expression.into());
return; return Some(Type::unknown());
}; };
let definition = index.expect_single_definition(first_param); let definition = index.expect_single_definition(first_param);
@ -3269,7 +3267,7 @@ impl KnownClass {
Type::unknown() Type::unknown()
}); });
overload_binding.set_return_type(bound_super); Some(bound_super)
} }
[Some(pivot_class_type), Some(owner_type)] => { [Some(pivot_class_type), Some(owner_type)] => {
let bound_super = BoundSuperType::build(db, *pivot_class_type, *owner_type) let bound_super = BoundSuperType::build(db, *pivot_class_type, *owner_type)
@ -3278,9 +3276,9 @@ impl KnownClass {
Type::unknown() Type::unknown()
}); });
overload_binding.set_return_type(bound_super); Some(bound_super)
} }
_ => {} _ => None,
} }
} }
@ -3295,14 +3293,12 @@ impl KnownClass {
_ => None, _ => None,
} }
}) else { }) else {
if let Some(builder) = let builder =
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression) context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?;
{
builder.into_diagnostic( builder.into_diagnostic(
"A legacy `typing.TypeVar` must be immediately assigned to a variable", "A legacy `typing.TypeVar` must be immediately assigned to a variable",
); );
} return None;
return;
}; };
let [ let [
@ -3315,7 +3311,7 @@ impl KnownClass {
_infer_variance, _infer_variance,
] = overload_binding.parameter_types() ] = overload_binding.parameter_types()
else { else {
return; return None;
}; };
let covariant = covariant let covariant = covariant
@ -3328,39 +3324,30 @@ impl KnownClass {
let variance = match (contravariant, covariant) { let variance = match (contravariant, covariant) {
(Truthiness::Ambiguous, _) => { (Truthiness::Ambiguous, _) => {
let Some(builder) = let builder =
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression) context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?;
else {
return;
};
builder.into_diagnostic( builder.into_diagnostic(
"The `contravariant` parameter of a legacy `typing.TypeVar` \ "The `contravariant` parameter of a legacy `typing.TypeVar` \
cannot have an ambiguous value", cannot have an ambiguous value",
); );
return; return None;
} }
(_, Truthiness::Ambiguous) => { (_, Truthiness::Ambiguous) => {
let Some(builder) = let builder =
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression) context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?;
else {
return;
};
builder.into_diagnostic( builder.into_diagnostic(
"The `covariant` parameter of a legacy `typing.TypeVar` \ "The `covariant` parameter of a legacy `typing.TypeVar` \
cannot have an ambiguous value", cannot have an ambiguous value",
); );
return; return None;
} }
(Truthiness::AlwaysTrue, Truthiness::AlwaysTrue) => { (Truthiness::AlwaysTrue, Truthiness::AlwaysTrue) => {
let Some(builder) = let builder =
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression) context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?;
else {
return;
};
builder.into_diagnostic( builder.into_diagnostic(
"A legacy `typing.TypeVar` cannot be both covariant and contravariant", "A legacy `typing.TypeVar` cannot be both covariant and contravariant",
); );
return; return None;
} }
(Truthiness::AlwaysTrue, Truthiness::AlwaysFalse) => { (Truthiness::AlwaysTrue, Truthiness::AlwaysFalse) => {
TypeVarVariance::Contravariant TypeVarVariance::Contravariant
@ -3374,11 +3361,8 @@ impl KnownClass {
let name_param = name_param.into_string_literal().map(|name| name.value(db)); let name_param = name_param.into_string_literal().map(|name| name.value(db));
if name_param.is_none_or(|name_param| name_param != target.id) { if name_param.is_none_or(|name_param| name_param != target.id) {
let Some(builder) = let builder =
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression) context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?;
else {
return;
};
builder.into_diagnostic(format_args!( builder.into_diagnostic(format_args!(
"The name of a legacy `typing.TypeVar`{} must match \ "The name of a legacy `typing.TypeVar`{} must match \
the name of the variable it is assigned to (`{}`)", the name of the variable it is assigned to (`{}`)",
@ -3389,7 +3373,7 @@ impl KnownClass {
}, },
target.id, target.id,
)); ));
return; return None;
} }
let bound_or_constraint = match (bound, constraints) { let bound_or_constraint = match (bound, constraints) {
@ -3414,13 +3398,13 @@ impl KnownClass {
// TODO: Emit a diagnostic that TypeVar cannot be both bounded and // TODO: Emit a diagnostic that TypeVar cannot be both bounded and
// constrained // constrained
(Some(_), Some(_)) => return, (Some(_), Some(_)) => return None,
(None, None) => None, (None, None) => None,
}; };
let containing_assignment = index.expect_single_definition(target); let containing_assignment = index.expect_single_definition(target);
overload_binding.set_return_type(Type::KnownInstance(KnownInstanceType::TypeVar( Some(Type::KnownInstance(KnownInstanceType::TypeVar(
TypeVarInstance::new( TypeVarInstance::new(
db, db,
target.id.clone(), target.id.clone(),
@ -3430,7 +3414,7 @@ impl KnownClass {
*default, *default,
TypeVarKind::Legacy, TypeVarKind::Legacy,
), ),
))); )))
} }
KnownClass::TypeAliasType => { KnownClass::TypeAliasType => {
@ -3446,30 +3430,31 @@ impl KnownClass {
}); });
let [Some(name), Some(value), ..] = overload_binding.parameter_types() else { let [Some(name), Some(value), ..] = overload_binding.parameter_types() else {
return; return None;
}; };
if let Some(name) = name.into_string_literal() { name.into_string_literal()
overload_binding.set_return_type(Type::KnownInstance( .map(|name| {
KnownInstanceType::TypeAliasType(TypeAliasType::Bare( Type::KnownInstance(KnownInstanceType::TypeAliasType(TypeAliasType::Bare(
BareTypeAliasType::new( BareTypeAliasType::new(
db, db,
ast::name::Name::new(name.value(db)), ast::name::Name::new(name.value(db)),
containing_assignment, containing_assignment,
value, value,
), ),
)), )))
)); })
} else if let Some(builder) = .or_else(|| {
context.report_lint(&INVALID_TYPE_ALIAS_TYPE, call_expression) let builder =
{ context.report_lint(&INVALID_TYPE_ALIAS_TYPE, call_expression)?;
builder.into_diagnostic( builder.into_diagnostic(
"The name of a `typing.TypeAlias` must be a string literal", "The name of a `typing.TypeAlias` must be a string literal",
); );
} None
})
} }
_ => {} _ => None,
} }
} }
} }

View file

@ -74,8 +74,7 @@ use crate::types::generics::GenericContext;
use crate::types::narrow::ClassInfoConstraintFunction; use crate::types::narrow::ClassInfoConstraintFunction;
use crate::types::signatures::{CallableSignature, Signature}; use crate::types::signatures::{CallableSignature, Signature};
use crate::types::{ use crate::types::{
Binding, BoundMethodType, CallableType, DynamicType, Type, TypeMapping, TypeRelation, BoundMethodType, CallableType, DynamicType, Type, TypeMapping, TypeRelation, TypeVarInstance,
TypeVarInstance,
}; };
use crate::{Db, FxOrderSet}; use crate::{Db, FxOrderSet};
@ -963,14 +962,14 @@ impl KnownFunction {
pub(super) fn check_call( pub(super) fn check_call(
self, self,
context: &InferContext, context: &InferContext,
overload_binding: &mut Binding, parameter_types: &[Option<Type<'_>>],
call_expression: &ast::ExprCall, call_expression: &ast::ExprCall,
) { ) {
let db = context.db(); let db = context.db();
match self { match self {
KnownFunction::RevealType => { KnownFunction::RevealType => {
let [Some(revealed_type)] = overload_binding.parameter_types() else { let [Some(revealed_type)] = parameter_types else {
return; return;
}; };
let Some(builder) = let Some(builder) =
@ -986,8 +985,7 @@ impl KnownFunction {
); );
} }
KnownFunction::AssertType => { KnownFunction::AssertType => {
let [Some(actual_ty), Some(asserted_ty)] = overload_binding.parameter_types() let [Some(actual_ty), Some(asserted_ty)] = parameter_types else {
else {
return; return;
}; };
@ -1019,7 +1017,7 @@ impl KnownFunction {
)); ));
} }
KnownFunction::AssertNever => { KnownFunction::AssertNever => {
let [Some(actual_ty)] = overload_binding.parameter_types() else { let [Some(actual_ty)] = parameter_types else {
return; return;
}; };
if actual_ty.is_equivalent_to(db, Type::Never) { if actual_ty.is_equivalent_to(db, Type::Never) {
@ -1045,7 +1043,7 @@ impl KnownFunction {
)); ));
} }
KnownFunction::StaticAssert => { KnownFunction::StaticAssert => {
let [Some(parameter_ty), message] = overload_binding.parameter_types() else { let [Some(parameter_ty), message] = parameter_types else {
return; return;
}; };
let truthiness = match parameter_ty.try_bool(db) { let truthiness = match parameter_ty.try_bool(db) {
@ -1100,8 +1098,7 @@ impl KnownFunction {
} }
} }
KnownFunction::Cast => { KnownFunction::Cast => {
let [Some(casted_type), Some(source_type)] = overload_binding.parameter_types() let [Some(casted_type), Some(source_type)] = parameter_types else {
else {
return; return;
}; };
let contains_unknown_or_todo = let contains_unknown_or_todo =
@ -1121,7 +1118,7 @@ impl KnownFunction {
} }
} }
KnownFunction::GetProtocolMembers => { KnownFunction::GetProtocolMembers => {
let [Some(Type::ClassLiteral(class))] = overload_binding.parameter_types() else { let [Some(Type::ClassLiteral(class))] = parameter_types else {
return; return;
}; };
if class.is_protocol(db) { if class.is_protocol(db) {
@ -1130,8 +1127,7 @@ impl KnownFunction {
report_bad_argument_to_get_protocol_members(context, call_expression, *class); report_bad_argument_to_get_protocol_members(context, call_expression, *class);
} }
KnownFunction::IsInstance | KnownFunction::IsSubclass => { KnownFunction::IsInstance | KnownFunction::IsSubclass => {
let [_, Some(Type::ClassLiteral(class))] = overload_binding.parameter_types() let [_, Some(Type::ClassLiteral(class))] = parameter_types else {
else {
return; return;
}; };
let Some(protocol_class) = class.into_protocol_class(db) else { let Some(protocol_class) = class.into_protocol_class(db) else {

View file

@ -5397,7 +5397,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
match binding_type { match binding_type {
Type::FunctionLiteral(function_literal) => { Type::FunctionLiteral(function_literal) => {
if let Some(known_function) = function_literal.known(self.db()) { if let Some(known_function) = function_literal.known(self.db()) {
known_function.check_call(&self.context, overload, call_expression); known_function.check_call(
&self.context,
overload.parameter_types(),
call_expression,
);
} }
} }
@ -5405,13 +5409,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let Some(known_class) = class.known(self.db()) else { let Some(known_class) = class.known(self.db()) else {
continue; continue;
}; };
known_class.check_call( let overridden_return = known_class.check_call(
&self.context, &self.context,
self.index, self.index,
overload, overload,
&call_argument_types, &call_argument_types,
call_expression, call_expression,
); );
if let Some(overridden_return) = overridden_return {
overload.set_return_type(overridden_return);
}
} }
_ => {} _ => {}
} }