diff --git a/Cargo.lock b/Cargo.lock index 212f965b45..e49fa71d59 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4443,6 +4443,7 @@ dependencies = [ "colored 3.0.0", "insta", "memchr", + "path-slash", "regex", "ruff_db", "ruff_index", diff --git a/crates/ty_python_semantic/resources/mdtest/diagnostics/same_names.md b/crates/ty_python_semantic/resources/mdtest/diagnostics/same_names.md index 6e6dfac60d..d7fff08a69 100644 --- a/crates/ty_python_semantic/resources/mdtest/diagnostics/same_names.md +++ b/crates/ty_python_semantic/resources/mdtest/diagnostics/same_names.md @@ -61,6 +61,35 @@ class DataFrame: pass ``` +## Class from different module with the same qualified name + +`package/__init__.py`: + +```py +from .foo import MyClass + +def make_MyClass() -> MyClass: + return MyClass() +``` + +`package/foo.pyi`: + +```pyi +class MyClass: ... +``` + +`package/foo.py`: + +```py +class MyClass: ... + +def get_MyClass() -> MyClass: + from . import make_MyClass + + # error: [invalid-return-type] "Return type does not match returned value: expected `package.foo.MyClass @ src/package/foo.py:1`, found `package.foo.MyClass @ src/package/foo.pyi:1`" + return make_MyClass() +``` + ## Enum from different modules ```py diff --git a/crates/ty_python_semantic/resources/mdtest/public_types.md b/crates/ty_python_semantic/resources/mdtest/public_types.md index 2b2ccb3ab5..56b6a803c5 100644 --- a/crates/ty_python_semantic/resources/mdtest/public_types.md +++ b/crates/ty_python_semantic/resources/mdtest/public_types.md @@ -339,7 +339,7 @@ class A: ... def f(x: A): # TODO: no error - # error: [invalid-assignment] "Object of type `mdtest_snippet.A | mdtest_snippet.A` is not assignable to `mdtest_snippet.A`" + # error: [invalid-assignment] "Object of type `mdtest_snippet.A @ src/mdtest_snippet.py:12 | mdtest_snippet.A @ src/mdtest_snippet.py:13` is not assignable to `mdtest_snippet.A @ src/mdtest_snippet.py:13`" x = A() ``` diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 896c2730de..566db82389 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1462,7 +1462,7 @@ impl<'db> ClassLiteral<'db> { .map(|generic_context| generic_context.promote_literals(db)) } - fn file(self, db: &dyn Db) -> File { + pub(super) fn file(self, db: &dyn Db) -> File { self.body_scope(db).file(db) } diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 2c52f40198..1bee509e26 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -1,11 +1,14 @@ //! Display implementations for types. +use std::borrow::Cow; use std::cell::RefCell; use std::collections::hash_map::Entry; use std::fmt::{self, Display, Formatter, Write}; use std::rc::Rc; use ruff_db::display::FormatterJoinExtension; +use ruff_db::files::FilePath; +use ruff_db::source::line_index; use ruff_python_ast::str::{Quote, TripleQuotes}; use ruff_python_literal::escape::AsciiEscape; use ruff_text_size::{TextRange, TextSize}; @@ -34,7 +37,7 @@ pub struct DisplaySettings<'db> { pub multiline: bool, /// Class names that should be displayed fully qualified /// (e.g., `module.ClassName` instead of just `ClassName`) - pub qualified: Rc>, + pub qualified: Rc>, /// Whether long unions are displayed in full pub preserve_full_unions: bool, } @@ -88,7 +91,9 @@ impl<'db> DisplaySettings<'db> { .class_names .borrow() .iter() - .filter_map(|(name, ambiguity)| ambiguity.is_ambiguous().then_some(*name)) + .filter_map(|(name, ambiguity)| { + Some((*name, QualificationLevel::from_ambiguity_state(ambiguity)?)) + }) .collect(), ), ..Self::default() @@ -96,6 +101,22 @@ impl<'db> DisplaySettings<'db> { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QualificationLevel { + ModuleName, + FileAndLineNumber, +} + +impl QualificationLevel { + const fn from_ambiguity_state(state: &AmbiguityState) -> Option { + match state { + AmbiguityState::Unambiguous(_) => None, + AmbiguityState::RequiresFullyQualifiedName { .. } => Some(Self::ModuleName), + AmbiguityState::RequiresFileAndLineNumber => Some(Self::FileAndLineNumber), + } + } +} + #[derive(Debug, Default)] struct AmbiguousClassCollector<'db> { visited_types: RefCell>>, @@ -110,10 +131,32 @@ impl<'db> AmbiguousClassCollector<'db> { } Entry::Occupied(mut entry) => { let value = entry.get_mut(); - if let AmbiguityState::Unambiguous(existing) = value - && *existing != class - { - *value = AmbiguityState::Ambiguous; + match value { + AmbiguityState::Unambiguous(existing) => { + if *existing != class { + let qualified_name_components = class.qualified_name_components(db); + if existing.qualified_name_components(db) == qualified_name_components { + *value = AmbiguityState::RequiresFileAndLineNumber; + } else { + *value = AmbiguityState::RequiresFullyQualifiedName { + class, + qualified_name_components, + }; + } + } + } + AmbiguityState::RequiresFullyQualifiedName { + class: existing, + qualified_name_components, + } => { + if *existing != class { + let new_components = class.qualified_name_components(db); + if *qualified_name_components == new_components { + *value = AmbiguityState::RequiresFileAndLineNumber; + } + } + } + AmbiguityState::RequiresFileAndLineNumber => {} } } } @@ -122,18 +165,18 @@ impl<'db> AmbiguousClassCollector<'db> { /// Whether or not a class can be unambiguously identified by its *unqualified* name /// given the other types that are present in the same context. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] enum AmbiguityState<'db> { /// The class can be displayed unambiguously using its unqualified name Unambiguous(ClassLiteral<'db>), /// The class must be displayed using its fully qualified name to avoid ambiguity. - Ambiguous, -} - -impl AmbiguityState<'_> { - const fn is_ambiguous(self) -> bool { - matches!(self, AmbiguityState::Ambiguous) - } + RequiresFullyQualifiedName { + class: ClassLiteral<'db>, + qualified_name_components: Vec, + }, + /// Even the class's fully qualified name is not sufficient; + /// we must also include the file and line number. + RequiresFileAndLineNumber, } impl<'db> super::visitor::TypeVisitor<'db> for AmbiguousClassCollector<'db> { @@ -232,21 +275,19 @@ impl<'db> ClassLiteral<'db> { settings, } } -} -struct ClassDisplay<'db> { - db: &'db dyn Db, - class: ClassLiteral<'db>, - settings: DisplaySettings<'db>, -} - -impl ClassDisplay<'_> { - fn class_parents(&self) -> Vec { - let body_scope = self.class.body_scope(self.db); - let file = body_scope.file(self.db); - let module_ast = parsed_module(self.db, file).load(self.db); - let index = semantic_index(self.db, file); - let file_scope_id = body_scope.file_scope_id(self.db); + /// Returns the components of the qualified name of this class, excluding this class itself. + /// + /// For example, calling this method on a class `C` in the module `a.b` would return + /// `["a", "b"]`. Calling this method on a class `D` inside the namespace of a method + /// `m` inside the namespace of a class `C` in the module `a.b` would return + /// `["a", "b", "C", ""]`. + fn qualified_name_components(self, db: &'db dyn Db) -> Vec { + let body_scope = self.body_scope(db); + let file = body_scope.file(db); + let module_ast = parsed_module(db, file).load(db); + let index = semantic_index(db, file); + let file_scope_id = body_scope.file_scope_id(db); let mut name_parts = vec![]; @@ -272,8 +313,8 @@ impl ClassDisplay<'_> { } } - if let Some(module) = file_to_module(self.db, file) { - let module_name = module.name(self.db); + if let Some(module) = file_to_module(db, file) { + let module_name = module.name(db); name_parts.push(module_name.as_str().to_string()); } @@ -282,19 +323,39 @@ impl ClassDisplay<'_> { } } +struct ClassDisplay<'db> { + db: &'db dyn Db, + class: ClassLiteral<'db>, + settings: DisplaySettings<'db>, +} + impl Display for ClassDisplay<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - if self - .settings - .qualified - .contains(&**self.class.name(self.db)) - { - for parent in self.class_parents() { + let qualification_level = self.settings.qualified.get(&**self.class.name(self.db)); + if qualification_level.is_some() { + for parent in self.class.qualified_name_components(self.db) { f.write_str(&parent)?; f.write_char('.')?; } } - f.write_str(self.class.name(self.db)) + f.write_str(self.class.name(self.db))?; + if qualification_level == Some(&QualificationLevel::FileAndLineNumber) { + let file = self.class.file(self.db); + let path = file.path(self.db); + let path = match path { + FilePath::System(path) => Cow::Owned(FilePath::System( + path.strip_prefix(self.db.system().current_directory()) + .unwrap_or(path) + .to_path_buf(), + )), + FilePath::Vendored(_) | FilePath::SystemVirtual(_) => Cow::Borrowed(path), + }; + let line_index = line_index(self.db, file); + let class_offset = self.class.header_range(self.db).start(); + let line_number = line_index.line_index(class_offset); + write!(f, " @ {path}:{line_number}")?; + } + Ok(()) } } diff --git a/crates/ty_test/Cargo.toml b/crates/ty_test/Cargo.toml index 97bd4bea2c..670c46cc91 100644 --- a/crates/ty_test/Cargo.toml +++ b/crates/ty_test/Cargo.toml @@ -28,6 +28,7 @@ camino = { workspace = true } colored = { workspace = true } insta = { workspace = true, features = ["filters"] } memchr = { workspace = true } +path-slash ={ workspace = true } regex = { workspace = true } rustc-hash = { workspace = true } rustc-stable-hash = { workspace = true } diff --git a/crates/ty_test/src/matcher.rs b/crates/ty_test/src/matcher.rs index 39fe8633ca..8c1baeff52 100644 --- a/crates/ty_test/src/matcher.rs +++ b/crates/ty_test/src/matcher.rs @@ -4,8 +4,10 @@ use std::borrow::Cow; use std::cmp::Ordering; use std::ops::Range; +use std::sync::LazyLock; use colored::Colorize; +use path_slash::PathExt; use ruff_db::diagnostic::{Diagnostic, DiagnosticId}; use ruff_db::files::File; use ruff_db::source::{SourceText, line_index, source_text}; @@ -201,8 +203,8 @@ impl UnmatchedWithColumn for &Diagnostic { fn discard_todo_metadata(ty: &str) -> Cow<'_, str> { #[cfg(not(debug_assertions))] { - static TODO_METADATA_REGEX: std::sync::LazyLock = - std::sync::LazyLock::new(|| regex::Regex::new(r"@Todo\([^)]*\)").unwrap()); + static TODO_METADATA_REGEX: LazyLock = + LazyLock::new(|| regex::Regex::new(r"@Todo\([^)]*\)").unwrap()); TODO_METADATA_REGEX.replace_all(ty, "@Todo") } @@ -211,6 +213,29 @@ fn discard_todo_metadata(ty: &str) -> Cow<'_, str> { Cow::Borrowed(ty) } +/// Normalize paths in diagnostics to Unix paths before comparing them against +/// the expected type. Doing otherwise means that it's hard to write cross-platform +/// tests, since in some edge cases the display of a type can include a path to the +/// file in which the type was defined (e.g. `foo.bar.A @ src/foo/bar.py:10` on Unix, +/// but `foo.bar.A @ src\foo\bar.py:10` on Windows). +fn normalize_paths(ty: &str) -> Cow<'_, str> { + static PATH_IN_CLASS_DISPLAY_REGEX: LazyLock = + LazyLock::new(|| regex::Regex::new(r"( @ )(.+)(\.pyi?:\d)").unwrap()); + + fn normalize_path_captures(path_captures: ®ex::Captures) -> String { + let normalized_path = std::path::Path::new(&path_captures[2]) + .to_slash() + .expect("Python module paths should be valid UTF-8"); + + format!( + "{}{}{}", + &path_captures[1], normalized_path, &path_captures[3] + ) + } + + PATH_IN_CLASS_DISPLAY_REGEX.replace_all(ty, normalize_path_captures) +} + struct Matcher { line_index: LineIndex, source: SourceText, @@ -294,7 +319,7 @@ impl Matcher { .column .is_none_or(|col| col == self.column(diagnostic)); let message_matches = error.message_contains.is_none_or(|needle| { - diagnostic.concise_message().to_string().contains(needle) + normalize_paths(&diagnostic.concise_message().to_string()).contains(needle) }); lint_name_matches && column_matches && message_matches });