diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index fe44fe3b9d..980afbdf46 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -260,6 +260,11 @@ def f(cond: bool) -> int: +```toml +[environment] +python-version = "3.12" +``` + ```py # error: [invalid-return-type] def f() -> int: @@ -279,6 +284,18 @@ T = TypeVar("T") # error: [invalid-return-type] def m(x: T) -> T: ... + +class A[T]: ... + +def f() -> A[int]: + class A[T]: ... + return A[int]() # error: [invalid-return-type] + +class B: ... + +def g() -> B: + class B: ... + return B() # error: [invalid-return-type] ``` ## Invalid return type in stub file diff --git a/crates/ty_python_semantic/resources/mdtest/snapshots/return_type.md_-_Function_return_type_-_Invalid_return_type_(a91e0c67519cd77f).snap b/crates/ty_python_semantic/resources/mdtest/snapshots/return_type.md_-_Function_return_type_-_Invalid_return_type_(a91e0c67519cd77f).snap index 9c3379586a..e0f0371720 100644 --- a/crates/ty_python_semantic/resources/mdtest/snapshots/return_type.md_-_Function_return_type_-_Invalid_return_type_(a91e0c67519cd77f).snap +++ b/crates/ty_python_semantic/resources/mdtest/snapshots/return_type.md_-_Function_return_type_-_Invalid_return_type_(a91e0c67519cd77f).snap @@ -30,6 +30,18 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/function/return_type.md 16 | 17 | # error: [invalid-return-type] 18 | def m(x: T) -> T: ... +19 | +20 | class A[T]: ... +21 | +22 | def f() -> A[int]: +23 | class A[T]: ... +24 | return A[int]() # error: [invalid-return-type] +25 | +26 | class B: ... +27 | +28 | def g() -> B: +29 | class B: ... +30 | return B() # error: [invalid-return-type] ``` # Diagnostics @@ -91,9 +103,45 @@ error[invalid-return-type]: Function always implicitly returns `None`, which is 17 | # error: [invalid-return-type] 18 | def m(x: T) -> T: ... | ^ +19 | +20 | class A[T]: ... | info: Consider changing the return annotation to `-> None` or adding a `return` statement info: Only functions in stub files, methods on protocol classes, or methods with `@abstractmethod` are permitted to have empty bodies info: rule `invalid-return-type` is enabled by default ``` + +``` +error[invalid-return-type]: Return type does not match returned value + --> src/mdtest_snippet.py:22:12 + | +20 | class A[T]: ... +21 | +22 | def f() -> A[int]: + | ------ Expected `mdtest_snippet.A[int]` because of return type +23 | class A[T]: ... +24 | return A[int]() # error: [invalid-return-type] + | ^^^^^^^^ expected `mdtest_snippet.A[int]`, found `mdtest_snippet..A[int]` +25 | +26 | class B: ... + | +info: rule `invalid-return-type` is enabled by default + +``` + +``` +error[invalid-return-type]: Return type does not match returned value + --> src/mdtest_snippet.py:28:12 + | +26 | class B: ... +27 | +28 | def g() -> B: + | - Expected `mdtest_snippet.B` because of return type +29 | class B: ... +30 | return B() # error: [invalid-return-type] + | ^^^ expected `mdtest_snippet.B`, found `mdtest_snippet..B` + | +info: rule `invalid-return-type` is enabled by default + +``` diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index ca953928ed..a75ce8796f 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -10,7 +10,7 @@ use crate::semantic_index::SemanticIndex; use crate::semantic_index::definition::Definition; use crate::semantic_index::place::{PlaceTable, ScopedPlaceId}; use crate::suppression::FileSuppressionId; -use crate::types::class::{ClassType, DisjointBase, DisjointBaseKind, Field}; +use crate::types::class::{DisjointBase, DisjointBaseKind, Field}; use crate::types::function::KnownFunction; use crate::types::string_annotation::{ BYTE_STRING_TYPE_ANNOTATION, ESCAPE_CHARACTER_IN_FORWARD_ANNOTATION, FSTRING_TYPE_ANNOTATION, @@ -1945,18 +1945,8 @@ pub(super) fn report_invalid_assignment( target_ty: Type, source_ty: Type, ) { - let mut settings = DisplaySettings::default(); - // Handles the situation where the report naming is confusing, such as class with the same Name, - // but from different scopes. - if let Some(target_class) = type_to_class_literal(target_ty, context.db()) { - if let Some(source_class) = type_to_class_literal(source_ty, context.db()) { - if target_class != source_class - && target_class.name(context.db()) == source_class.name(context.db()) - { - settings = settings.qualified(); - } - } - } + let settings = + DisplaySettings::from_possibly_ambiguous_type_pair(context.db(), target_ty, source_ty); report_invalid_assignment_with_message( context, @@ -1970,36 +1960,6 @@ pub(super) fn report_invalid_assignment( ); } -// TODO: generalize this to a method that takes any two types, walks them recursively, and returns -// a set of types with ambiguous names whose display should be qualified. Then we can use this in -// any diagnostic that displays two types. -fn type_to_class_literal<'db>(ty: Type<'db>, db: &'db dyn crate::Db) -> Option> { - match ty { - Type::ClassLiteral(class) => Some(class), - Type::NominalInstance(instance) => match instance.class(db) { - crate::types::class::ClassType::NonGeneric(class) => Some(class), - crate::types::class::ClassType::Generic(alias) => Some(alias.origin(db)), - }, - Type::EnumLiteral(enum_literal) => Some(enum_literal.enum_class(db)), - Type::GenericAlias(alias) => Some(alias.origin(db)), - Type::ProtocolInstance(ProtocolInstanceType { - inner: Protocol::FromClass(class), - .. - }) => match class { - ClassType::NonGeneric(class) => Some(class), - ClassType::Generic(alias) => Some(alias.origin(db)), - }, - Type::TypedDict(typed_dict) => match typed_dict.defining_class() { - ClassType::NonGeneric(class) => Some(class), - ClassType::Generic(alias) => Some(alias.origin(db)), - }, - Type::SubclassOf(subclass_of) => { - type_to_class_literal(Type::from(subclass_of.subclass_of().into_class()?), db) - } - _ => None, - } -} - pub(super) fn report_invalid_attribute_assignment( context: &InferContext, node: AnyNodeRef, @@ -2030,18 +1990,20 @@ pub(super) fn report_invalid_return_type( return; }; + let settings = + DisplaySettings::from_possibly_ambiguous_type_pair(context.db(), expected_ty, actual_ty); let return_type_span = context.span(return_type_range); let mut diag = builder.into_diagnostic("Return type does not match returned value"); diag.set_primary_message(format_args!( "expected `{expected_ty}`, found `{actual_ty}`", - expected_ty = expected_ty.display(context.db()), - actual_ty = actual_ty.display(context.db()), + expected_ty = expected_ty.display_with(context.db(), settings), + actual_ty = actual_ty.display_with(context.db(), settings), )); diag.annotate( Annotation::secondary(return_type_span).message(format_args!( "Expected `{expected_ty}` because of return type", - expected_ty = expected_ty.display(context.db()), + expected_ty = expected_ty.display_with(context.db(), settings), )), ); } diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 2c28957510..a54d0e92f7 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -17,8 +17,8 @@ use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signatu use crate::types::tuple::TupleSpec; use crate::types::{ BoundTypeVarInstance, CallableType, IntersectionType, KnownClass, MaterializationKind, - MethodWrapperKind, Protocol, StringLiteralType, SubclassOfInner, Type, UnionType, - WrapperDescriptorKind, + MethodWrapperKind, Protocol, ProtocolInstanceType, StringLiteralType, SubclassOfInner, Type, + UnionType, WrapperDescriptorKind, }; use ruff_db::parsed::parsed_module; @@ -55,6 +55,58 @@ impl DisplaySettings { ..self } } + + #[must_use] + pub fn from_possibly_ambiguous_type_pair<'db>( + db: &'db dyn Db, + type_1: Type<'db>, + type_2: Type<'db>, + ) -> Self { + let result = Self::default(); + + let Some(class_1) = type_to_class_literal(db, type_1) else { + return result; + }; + + let Some(class_2) = type_to_class_literal(db, type_2) else { + return result; + }; + + if class_1 == class_2 { + return result; + } + + if class_1.name(db) == class_2.name(db) { + result.qualified() + } else { + result + } + } +} + +// TODO: generalize this to a method that takes any two types, walks them recursively, and returns +// a set of types with ambiguous names whose display should be qualified. Then we can use this in +// any diagnostic that displays two types. +fn type_to_class_literal<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option> { + match ty { + Type::ClassLiteral(class) => Some(class), + Type::NominalInstance(instance) => { + type_to_class_literal(db, Type::from(instance.class(db))) + } + Type::EnumLiteral(enum_literal) => Some(enum_literal.enum_class(db)), + Type::GenericAlias(alias) => Some(alias.origin(db)), + Type::ProtocolInstance(ProtocolInstanceType { + inner: Protocol::FromClass(class), + .. + }) => type_to_class_literal(db, Type::from(class)), + Type::TypedDict(typed_dict) => { + type_to_class_literal(db, Type::from(typed_dict.defining_class())) + } + Type::SubclassOf(subclass_of) => { + type_to_class_literal(db, Type::from(subclass_of.subclass_of().into_class()?)) + } + _ => None, + } } impl<'db> Type<'db> { @@ -114,18 +166,25 @@ impl fmt::Debug for DisplayType<'_> { } } -/// Writes the string representation of a type, which is the value displayed either as -/// `Literal[]` or `Literal[, ]` for literal types or as `` for -/// non literals -struct DisplayRepresentation<'db> { - ty: Type<'db>, +impl<'db> ClassLiteral<'db> { + fn display_with(self, db: &'db dyn Db, settings: DisplaySettings) -> ClassDisplay<'db> { + ClassDisplay { + db, + class: self, + settings, + } + } +} + +struct ClassDisplay<'db> { db: &'db dyn Db, + class: ClassLiteral<'db>, settings: DisplaySettings, } -impl DisplayRepresentation<'_> { - fn class_parents(&self, class: ClassLiteral) -> Vec { - let body_scope = class.body_scope(self.db); +impl ClassDisplay<'_> { + fn class_parents(&self) -> Vec { + let body_scope = self.class.body_scope(self.db); let file = body_scope.file(self.db); let module_ast = parsed_module(self.db, file).load(self.db); let index = semantic_index(self.db, file); @@ -165,23 +224,29 @@ impl DisplayRepresentation<'_> { name_parts.reverse(); name_parts } +} - fn write_maybe_qualified_class( - &self, - f: &mut Formatter<'_>, - class: ClassLiteral, - ) -> fmt::Result { +impl Display for ClassDisplay<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { if self.settings.qualified { - let parents = self.class_parents(class); - if !parents.is_empty() { - f.write_str(&parents.join("."))?; + for parent in self.class_parents() { + f.write_str(&parent)?; f.write_char('.')?; } } - f.write_str(class.name(self.db)) + f.write_str(self.class.name(self.db)) } } +/// Writes the string representation of a type, which is the value displayed either as +/// `Literal[]` or `Literal[, ]` for literal types or as `` for +/// non literals +struct DisplayRepresentation<'db> { + ty: Type<'db>, + db: &'db dyn Db, + settings: DisplaySettings, +} + impl Display for DisplayRepresentation<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self.ty { @@ -200,14 +265,14 @@ impl Display for DisplayRepresentation<'_> { .display_with(self.db, self.settings) .fmt(f), (ClassType::NonGeneric(class), _) => { - self.write_maybe_qualified_class(f, class) + class.display_with(self.db, self.settings).fmt(f) }, (ClassType::Generic(alias), _) => alias.display_with(self.db, self.settings).fmt(f), } } Type::ProtocolInstance(protocol) => match protocol.inner { Protocol::FromClass(ClassType::NonGeneric(class)) => { - self.write_maybe_qualified_class(f, class) + class.display_with(self.db, self.settings).fmt(f) } Protocol::FromClass(ClassType::Generic(alias)) => { alias.display_with(self.db, self.settings).fmt(f) @@ -231,11 +296,11 @@ impl Display for DisplayRepresentation<'_> { Type::ModuleLiteral(module) => { write!(f, "", module.module(self.db).name(self.db)) } - Type::ClassLiteral(class) => { - write!(f, "") - } + Type::ClassLiteral(class) => write!( + f, + "", + class.display_with(self.db, self.settings) + ), Type::GenericAlias(generic) => write!( f, "", @@ -243,9 +308,7 @@ impl Display for DisplayRepresentation<'_> { ), Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() { SubclassOfInner::Class(ClassType::NonGeneric(class)) => { - write!(f, "type[")?; - self.write_maybe_qualified_class(f, class)?; - write!(f, "]") + write!(f, "type[{}]", class.display_with(self.db, self.settings)) } SubclassOfInner::Class(ClassType::Generic(alias)) => { write!( @@ -320,13 +383,13 @@ impl Display for DisplayRepresentation<'_> { ) } Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(_)) => { - write!(f, "",) + f.write_str("") } Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet(_)) => { - write!(f, "",) + f.write_str("") } Type::MethodWrapper(MethodWrapperKind::StrStartswith(_)) => { - write!(f, "",) + f.write_str("") } Type::WrapperDescriptor(kind) => { let (method, object) = match kind { @@ -355,11 +418,14 @@ impl Display for DisplayRepresentation<'_> { escape.bytes_repr(TripleQuotes::No).write(f) } - Type::EnumLiteral(enum_literal) => { - self.write_maybe_qualified_class(f, enum_literal.enum_class(self.db))?; - f.write_char('.')?; - f.write_str(enum_literal.name(self.db)) - } + Type::EnumLiteral(enum_literal) => write!( + f, + "{enum_class}.{literal_name}", + enum_class = enum_literal + .enum_class(self.db) + .display_with(self.db, self.settings), + literal_name = enum_literal.name(self.db) + ), Type::NonInferableTypeVar(bound_typevar) | Type::TypeVar(bound_typevar) => { bound_typevar.display(self.db).fmt(f) } @@ -389,10 +455,12 @@ impl Display for DisplayRepresentation<'_> { } f.write_str("]") } - Type::TypedDict(typed_dict) => self.write_maybe_qualified_class( - f, - typed_dict.defining_class().class_literal(self.db).0, - ), + Type::TypedDict(typed_dict) => typed_dict + .defining_class() + .class_literal(self.db) + .0 + .display_with(self.db, self.settings) + .fmt(f), Type::TypeAlias(alias) => f.write_str(alias.name(self.db)), } } @@ -647,7 +715,7 @@ impl Display for DisplayGenericAlias<'_> { f, "{prefix}{origin}{specialization}{suffix}", prefix = prefix, - origin = self.origin.name(self.db), + origin = self.origin.display_with(self.db, self.settings), specialization = self.specialization.display_short( self.db, TupleSpecialization::from_class(self.db, self.origin)