diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index d5ef1b9fea..6b7752df6a 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -3201,16 +3201,16 @@ impl KnownClass { } } - /// Evaluate a call to this known class, and emit any diagnostics that are necessary - /// as a result of the call. + /// Evaluate a call to this known class, emit any diagnostics that are necessary + /// as a result of the call, and return the type that results from the call. pub(super) fn check_call<'db>( self, context: &InferContext<'db, '_>, index: &SemanticIndex<'db>, - overload_binding: &mut Binding<'db>, + overload_binding: &Binding<'db>, call_argument_types: &CallArgumentTypes<'_, 'db>, call_expression: &ast::ExprCall, - ) { + ) -> Option> { let db = context.db(); let scope = context.scope(); let module = context.module(); @@ -3226,10 +3226,9 @@ impl KnownClass { let Some(enclosing_class) = nearest_enclosing_class(db, index, scope, module) else { - overload_binding.set_return_type(Type::unknown()); BoundSuperError::UnavailableImplicitArguments .report_diagnostic(context, call_expression.into()); - return; + return Some(Type::unknown()); }; // The type of the first parameter if the given scope is function-like (i.e. function or lambda). @@ -3249,10 +3248,9 @@ impl KnownClass { }; let Some(first_param) = first_param else { - overload_binding.set_return_type(Type::unknown()); BoundSuperError::UnavailableImplicitArguments .report_diagnostic(context, call_expression.into()); - return; + return Some(Type::unknown()); }; let definition = index.expect_single_definition(first_param); @@ -3269,7 +3267,7 @@ impl KnownClass { Type::unknown() }); - overload_binding.set_return_type(bound_super); + Some(bound_super) } [Some(pivot_class_type), Some(owner_type)] => { let bound_super = BoundSuperType::build(db, *pivot_class_type, *owner_type) @@ -3278,9 +3276,9 @@ impl KnownClass { Type::unknown() }); - overload_binding.set_return_type(bound_super); + Some(bound_super) } - _ => {} + _ => None, } } @@ -3295,14 +3293,12 @@ impl KnownClass { _ => None, } }) else { - if let Some(builder) = - context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression) - { - builder.into_diagnostic( - "A legacy `typing.TypeVar` must be immediately assigned to a variable", - ); - } - return; + let builder = + context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?; + builder.into_diagnostic( + "A legacy `typing.TypeVar` must be immediately assigned to a variable", + ); + return None; }; let [ @@ -3315,7 +3311,7 @@ impl KnownClass { _infer_variance, ] = overload_binding.parameter_types() else { - return; + return None; }; let covariant = covariant @@ -3328,39 +3324,30 @@ impl KnownClass { let variance = match (contravariant, covariant) { (Truthiness::Ambiguous, _) => { - let Some(builder) = - context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression) - else { - return; - }; + let builder = + context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?; builder.into_diagnostic( "The `contravariant` parameter of a legacy `typing.TypeVar` \ cannot have an ambiguous value", ); - return; + return None; } (_, Truthiness::Ambiguous) => { - let Some(builder) = - context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression) - else { - return; - }; + let builder = + context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?; builder.into_diagnostic( "The `covariant` parameter of a legacy `typing.TypeVar` \ cannot have an ambiguous value", ); - return; + return None; } (Truthiness::AlwaysTrue, Truthiness::AlwaysTrue) => { - let Some(builder) = - context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression) - else { - return; - }; + let builder = + context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?; builder.into_diagnostic( "A legacy `typing.TypeVar` cannot be both covariant and contravariant", ); - return; + return None; } (Truthiness::AlwaysTrue, Truthiness::AlwaysFalse) => { TypeVarVariance::Contravariant @@ -3374,11 +3361,8 @@ impl KnownClass { let name_param = name_param.into_string_literal().map(|name| name.value(db)); if name_param.is_none_or(|name_param| name_param != target.id) { - let Some(builder) = - context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression) - else { - return; - }; + let builder = + context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?; builder.into_diagnostic(format_args!( "The name of a legacy `typing.TypeVar`{} must match \ the name of the variable it is assigned to (`{}`)", @@ -3389,7 +3373,7 @@ impl KnownClass { }, target.id, )); - return; + return None; } let bound_or_constraint = match (bound, constraints) { @@ -3414,13 +3398,13 @@ impl KnownClass { // TODO: Emit a diagnostic that TypeVar cannot be both bounded and // constrained - (Some(_), Some(_)) => return, + (Some(_), Some(_)) => return None, (None, None) => None, }; let containing_assignment = index.expect_single_definition(target); - overload_binding.set_return_type(Type::KnownInstance(KnownInstanceType::TypeVar( + Some(Type::KnownInstance(KnownInstanceType::TypeVar( TypeVarInstance::new( db, target.id.clone(), @@ -3430,7 +3414,7 @@ impl KnownClass { *default, TypeVarKind::Legacy, ), - ))); + ))) } KnownClass::TypeAliasType => { @@ -3446,30 +3430,31 @@ impl KnownClass { }); let [Some(name), Some(value), ..] = overload_binding.parameter_types() else { - return; + return None; }; - if let Some(name) = name.into_string_literal() { - overload_binding.set_return_type(Type::KnownInstance( - KnownInstanceType::TypeAliasType(TypeAliasType::Bare( + name.into_string_literal() + .map(|name| { + Type::KnownInstance(KnownInstanceType::TypeAliasType(TypeAliasType::Bare( BareTypeAliasType::new( db, ast::name::Name::new(name.value(db)), containing_assignment, value, ), - )), - )); - } else if let Some(builder) = - context.report_lint(&INVALID_TYPE_ALIAS_TYPE, call_expression) - { - builder.into_diagnostic( - "The name of a `typing.TypeAlias` must be a string literal", - ); - } + ))) + }) + .or_else(|| { + let builder = + context.report_lint(&INVALID_TYPE_ALIAS_TYPE, call_expression)?; + builder.into_diagnostic( + "The name of a `typing.TypeAlias` must be a string literal", + ); + None + }) } - _ => {} + _ => None, } } } diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 523a09123b..0a21ae7f5a 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -74,8 +74,7 @@ use crate::types::generics::GenericContext; use crate::types::narrow::ClassInfoConstraintFunction; use crate::types::signatures::{CallableSignature, Signature}; use crate::types::{ - Binding, BoundMethodType, CallableType, DynamicType, Type, TypeMapping, TypeRelation, - TypeVarInstance, + BoundMethodType, CallableType, DynamicType, Type, TypeMapping, TypeRelation, TypeVarInstance, }; use crate::{Db, FxOrderSet}; @@ -963,14 +962,14 @@ impl KnownFunction { pub(super) fn check_call( self, context: &InferContext, - overload_binding: &mut Binding, + parameter_types: &[Option>], call_expression: &ast::ExprCall, ) { let db = context.db(); match self { KnownFunction::RevealType => { - let [Some(revealed_type)] = overload_binding.parameter_types() else { + let [Some(revealed_type)] = parameter_types else { return; }; let Some(builder) = @@ -986,8 +985,7 @@ impl KnownFunction { ); } KnownFunction::AssertType => { - let [Some(actual_ty), Some(asserted_ty)] = overload_binding.parameter_types() - else { + let [Some(actual_ty), Some(asserted_ty)] = parameter_types else { return; }; @@ -1019,7 +1017,7 @@ impl KnownFunction { )); } KnownFunction::AssertNever => { - let [Some(actual_ty)] = overload_binding.parameter_types() else { + let [Some(actual_ty)] = parameter_types else { return; }; if actual_ty.is_equivalent_to(db, Type::Never) { @@ -1045,7 +1043,7 @@ impl KnownFunction { )); } KnownFunction::StaticAssert => { - let [Some(parameter_ty), message] = overload_binding.parameter_types() else { + let [Some(parameter_ty), message] = parameter_types else { return; }; let truthiness = match parameter_ty.try_bool(db) { @@ -1100,8 +1098,7 @@ impl KnownFunction { } } KnownFunction::Cast => { - let [Some(casted_type), Some(source_type)] = overload_binding.parameter_types() - else { + let [Some(casted_type), Some(source_type)] = parameter_types else { return; }; let contains_unknown_or_todo = @@ -1121,7 +1118,7 @@ impl KnownFunction { } } KnownFunction::GetProtocolMembers => { - let [Some(Type::ClassLiteral(class))] = overload_binding.parameter_types() else { + let [Some(Type::ClassLiteral(class))] = parameter_types else { return; }; if class.is_protocol(db) { @@ -1130,8 +1127,7 @@ impl KnownFunction { report_bad_argument_to_get_protocol_members(context, call_expression, *class); } KnownFunction::IsInstance | KnownFunction::IsSubclass => { - let [_, Some(Type::ClassLiteral(class))] = overload_binding.parameter_types() - else { + let [_, Some(Type::ClassLiteral(class))] = parameter_types else { return; }; let Some(protocol_class) = class.into_protocol_class(db) else { diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index f464325d02..c8c03ce000 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -5397,7 +5397,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match binding_type { Type::FunctionLiteral(function_literal) => { if let Some(known_function) = function_literal.known(self.db()) { - known_function.check_call(&self.context, overload, call_expression); + known_function.check_call( + &self.context, + overload.parameter_types(), + call_expression, + ); } } @@ -5405,13 +5409,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let Some(known_class) = class.known(self.db()) else { continue; }; - known_class.check_call( + let overridden_return = known_class.check_call( &self.context, self.index, overload, &call_argument_types, call_expression, ); + if let Some(overridden_return) = overridden_return { + overload.set_return_type(overridden_return); + } } _ => {} }