diff --git a/crates/ty_ide/src/doc_highlights.rs b/crates/ty_ide/src/doc_highlights.rs new file mode 100644 index 0000000000..ec6a42d6a1 --- /dev/null +++ b/crates/ty_ide/src/doc_highlights.rs @@ -0,0 +1,237 @@ +use crate::goto::find_goto_target; +use crate::references::{ReferencesMode, references}; +use crate::{Db, ReferenceTarget}; +use ruff_db::files::File; +use ruff_text_size::TextSize; + +/// Find all document highlights for a symbol at the given position. +/// Document highlights are limited to the current file only. +pub fn document_highlights( + db: &dyn Db, + file: File, + offset: TextSize, +) -> Option> { + let parsed = ruff_db::parsed::parsed_module(db, file); + let module = parsed.load(db); + + // Get the definitions for the symbol at the cursor position + let goto_target = find_goto_target(&module, offset)?; + + // Use DocumentHighlights mode which limits search to current file only + references(db, file, &goto_target, ReferencesMode::DocumentHighlights) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{CursorTest, IntoDiagnostic, cursor_test}; + use insta::assert_snapshot; + use ruff_db::diagnostic::{Annotation, Diagnostic, DiagnosticId, LintName, Severity, Span}; + use ruff_db::files::FileRange; + use ruff_text_size::Ranged; + + impl CursorTest { + fn document_highlights(&self) -> String { + let Some(highlight_results) = + document_highlights(&self.db, self.cursor.file, self.cursor.offset) + else { + return "No highlights found".to_string(); + }; + + if highlight_results.is_empty() { + return "No highlights found".to_string(); + } + + self.render_diagnostics(highlight_results.into_iter().enumerate().map( + |(i, highlight_item)| -> HighlightResult { + HighlightResult { + index: i, + file_range: FileRange::new(highlight_item.file(), highlight_item.range()), + kind: highlight_item.kind(), + } + }, + )) + } + } + + struct HighlightResult { + index: usize, + file_range: FileRange, + kind: crate::ReferenceKind, + } + + impl IntoDiagnostic for HighlightResult { + fn into_diagnostic(self) -> Diagnostic { + let kind_str = match self.kind { + crate::ReferenceKind::Read => "Read", + crate::ReferenceKind::Write => "Write", + crate::ReferenceKind::Other => "Other", + }; + let mut main = Diagnostic::new( + DiagnosticId::Lint(LintName::of("document_highlights")), + Severity::Info, + format!("Highlight {} ({})", self.index + 1, kind_str), + ); + main.annotate(Annotation::primary( + Span::from(self.file_range.file()).with_range(self.file_range.range()), + )); + + main + } + } + + #[test] + fn test_local_variable_highlights() { + let test = cursor_test( + " +def calculate_sum(): + value = 10 + doubled = value * 2 + result = value + doubled + return value +", + ); + + assert_snapshot!(test.document_highlights(), @r" + info[document_highlights]: Highlight 1 (Write) + --> main.py:3:5 + | + 2 | def calculate_sum(): + 3 | value = 10 + | ^^^^^ + 4 | doubled = value * 2 + 5 | result = value + doubled + | + + info[document_highlights]: Highlight 2 (Read) + --> main.py:4:15 + | + 2 | def calculate_sum(): + 3 | value = 10 + 4 | doubled = value * 2 + | ^^^^^ + 5 | result = value + doubled + 6 | return value + | + + info[document_highlights]: Highlight 3 (Read) + --> main.py:5:14 + | + 3 | value = 10 + 4 | doubled = value * 2 + 5 | result = value + doubled + | ^^^^^ + 6 | return value + | + + info[document_highlights]: Highlight 4 (Read) + --> main.py:6:12 + | + 4 | doubled = value * 2 + 5 | result = value + doubled + 6 | return value + | ^^^^^ + | + "); + } + + #[test] + fn test_parameter_highlights() { + let test = cursor_test( + " +def process_data(data): + if data: + processed = data.upper() + return processed + return data +", + ); + + assert_snapshot!(test.document_highlights(), @r" + info[document_highlights]: Highlight 1 (Other) + --> main.py:2:18 + | + 2 | def process_data(data): + | ^^^^ + 3 | if data: + 4 | processed = data.upper() + | + + info[document_highlights]: Highlight 2 (Read) + --> main.py:3:8 + | + 2 | def process_data(data): + 3 | if data: + | ^^^^ + 4 | processed = data.upper() + 5 | return processed + | + + info[document_highlights]: Highlight 3 (Read) + --> main.py:4:21 + | + 2 | def process_data(data): + 3 | if data: + 4 | processed = data.upper() + | ^^^^ + 5 | return processed + 6 | return data + | + + info[document_highlights]: Highlight 4 (Read) + --> main.py:6:12 + | + 4 | processed = data.upper() + 5 | return processed + 6 | return data + | ^^^^ + | + "); + } + + #[test] + fn test_class_name_highlights() { + let test = cursor_test( + " +class Calculator: + def __init__(self): + self.name = 'Calculator' + +calc = Calculator() +", + ); + + assert_snapshot!(test.document_highlights(), @r" + info[document_highlights]: Highlight 1 (Other) + --> main.py:2:7 + | + 2 | class Calculator: + | ^^^^^^^^^^ + 3 | def __init__(self): + 4 | self.name = 'Calculator' + | + + info[document_highlights]: Highlight 2 (Read) + --> main.py:6:8 + | + 4 | self.name = 'Calculator' + 5 | + 6 | calc = Calculator() + | ^^^^^^^^^^ + | + "); + } + + #[test] + fn test_no_highlights_for_unknown_symbol() { + let test = cursor_test( + " +def test(): + # Cursor on a position with no symbol + +", + ); + + assert_snapshot!(test.document_highlights(), @"No highlights found"); + } +} diff --git a/crates/ty_ide/src/goto_references.rs b/crates/ty_ide/src/goto_references.rs new file mode 100644 index 0000000000..1f6d5063f5 --- /dev/null +++ b/crates/ty_ide/src/goto_references.rs @@ -0,0 +1,856 @@ +use crate::goto::find_goto_target; +use crate::references::{ReferencesMode, references}; +use crate::{Db, ReferenceTarget}; +use ruff_db::files::File; +use ruff_text_size::TextSize; + +/// Find all references to a symbol at the given position. +/// Search for references across all files in the project. +pub fn goto_references( + db: &dyn Db, + file: File, + offset: TextSize, + include_declaration: bool, +) -> Option> { + let parsed = ruff_db::parsed::parsed_module(db, file); + let module = parsed.load(db); + + // Get the definitions for the symbol at the cursor position + let goto_target = find_goto_target(&module, offset)?; + + let mode = if include_declaration { + ReferencesMode::References + } else { + ReferencesMode::ReferencesSkipDeclaration + }; + + references(db, file, &goto_target, mode) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{CursorTest, IntoDiagnostic, cursor_test}; + use insta::assert_snapshot; + use ruff_db::diagnostic::{Annotation, Diagnostic, DiagnosticId, LintName, Severity, Span}; + use ruff_db::files::FileRange; + use ruff_text_size::Ranged; + + impl CursorTest { + fn references(&self) -> String { + let Some(reference_results) = + goto_references(&self.db, self.cursor.file, self.cursor.offset, true) + else { + return "No references found".to_string(); + }; + + if reference_results.is_empty() { + return "No references found".to_string(); + } + + self.render_diagnostics(reference_results.into_iter().enumerate().map( + |(i, ref_item)| -> ReferenceResult { + ReferenceResult { + index: i, + file_range: FileRange::new(ref_item.file(), ref_item.range()), + } + }, + )) + } + } + + struct ReferenceResult { + index: usize, + file_range: FileRange, + } + + impl IntoDiagnostic for ReferenceResult { + fn into_diagnostic(self) -> Diagnostic { + let mut main = Diagnostic::new( + DiagnosticId::Lint(LintName::of("references")), + Severity::Info, + format!("Reference {}", self.index + 1), + ); + main.annotate(Annotation::primary( + Span::from(self.file_range.file()).with_range(self.file_range.range()), + )); + + main + } + } + + #[test] + fn test_parameter_references_in_function() { + let test = cursor_test( + " +def calculate_sum(value: int) -> int: + doubled = value * 2 + result = value + doubled + return value + +# Call with keyword argument +result = calculate_sum(value=42) +", + ); + + assert_snapshot!(test.references(), @r###" + info[references]: Reference 1 + --> main.py:2:19 + | + 2 | def calculate_sum(value: int) -> int: + | ^^^^^ + 3 | doubled = value * 2 + 4 | result = value + doubled + | + + info[references]: Reference 2 + --> main.py:3:15 + | + 2 | def calculate_sum(value: int) -> int: + 3 | doubled = value * 2 + | ^^^^^ + 4 | result = value + doubled + 5 | return value + | + + info[references]: Reference 3 + --> main.py:4:14 + | + 2 | def calculate_sum(value: int) -> int: + 3 | doubled = value * 2 + 4 | result = value + doubled + | ^^^^^ + 5 | return value + | + + info[references]: Reference 4 + --> main.py:5:12 + | + 3 | doubled = value * 2 + 4 | result = value + doubled + 5 | return value + | ^^^^^ + 6 | + 7 | # Call with keyword argument + | + + info[references]: Reference 5 + --> main.py:8:24 + | + 7 | # Call with keyword argument + 8 | result = calculate_sum(value=42) + | ^^^^^ + | + "###); + } + + #[test] + #[ignore] // TODO: Enable when nonlocal support is fully implemented in goto.rs + fn test_nonlocal_variable_references() { + let test = cursor_test( + " +def outer_function(): + counter = 0 + + def increment(): + nonlocal counter + counter += 1 + return counter + + def decrement(): + nonlocal counter + counter -= 1 + return counter + + # Use counter in outer scope + initial = counter + increment() + decrement() + final = counter + + return increment, decrement +", + ); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:3:5 + | + 2 | def outer_function(): + 3 | counter = 0 + | ^^^^^^^ + 4 | + 5 | def increment(): + | + + info[references]: Reference 2 + --> main.py:6:18 + | + 5 | def increment(): + 6 | nonlocal counter + | ^^^^^^^ + 7 | counter += 1 + 8 | return counter + | + + info[references]: Reference 3 + --> main.py:7:9 + | + 5 | def increment(): + 6 | nonlocal counter + 7 | counter += 1 + | ^^^^^^^ + 8 | return counter + | + + info[references]: Reference 4 + --> main.py:8:16 + | + 6 | nonlocal counter + 7 | counter += 1 + 8 | return counter + | ^^^^^^^ + 9 | + 10 | def decrement(): + | + + info[references]: Reference 5 + --> main.py:11:18 + | + 10 | def decrement(): + 11 | nonlocal counter + | ^^^^^^^ + 12 | counter -= 1 + 13 | return counter + | + + info[references]: Reference 6 + --> main.py:12:9 + | + 10 | def decrement(): + 11 | nonlocal counter + 12 | counter -= 1 + | ^^^^^^^ + 13 | return counter + | + + info[references]: Reference 7 + --> main.py:13:16 + | + 11 | nonlocal counter + 12 | counter -= 1 + 13 | return counter + | ^^^^^^^ + 14 | + 15 | # Use counter in outer scope + | + + info[references]: Reference 8 + --> main.py:16:15 + | + 15 | # Use counter in outer scope + 16 | initial = counter + | ^^^^^^^ + 17 | increment() + 18 | decrement() + | + + info[references]: Reference 9 + --> main.py:19:13 + | + 17 | increment() + 18 | decrement() + 19 | final = counter + | ^^^^^^^ + 20 | + 21 | return increment, decrement + | + "); + } + + #[test] + #[ignore] // TODO: Enable when global support is fully implemented in goto.rs + fn test_global_variable_references() { + let test = cursor_test( + " +global_counter = 0 + +def increment_global(): + global global_counter + global_counter += 1 + return global_counter + +def decrement_global(): + global global_counter + global_counter -= 1 + return global_counter + +# Use global_counter at module level +initial_value = global_counter +increment_global() +decrement_global() +final_value = global_counter +", + ); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:2:1 + | + 2 | global_counter = 0 + | ^^^^^^^^^^^^^^ + 3 | + 4 | def increment_global(): + | + + info[references]: Reference 2 + --> main.py:5:12 + | + 4 | def increment_global(): + 5 | global global_counter + | ^^^^^^^^^^^^^^ + 6 | global_counter += 1 + 7 | return global_counter + | + + info[references]: Reference 3 + --> main.py:6:5 + | + 4 | def increment_global(): + 5 | global global_counter + 6 | global_counter += 1 + | ^^^^^^^^^^^^^^ + 7 | return global_counter + | + + info[references]: Reference 4 + --> main.py:7:12 + | + 5 | global global_counter + 6 | global_counter += 1 + 7 | return global_counter + | ^^^^^^^^^^^^^^ + 8 | + 9 | def decrement_global(): + | + + info[references]: Reference 5 + --> main.py:10:12 + | + 9 | def decrement_global(): + 10 | global global_counter + | ^^^^^^^^^^^^^^ + 11 | global_counter -= 1 + 12 | return global_counter + | + + info[references]: Reference 6 + --> main.py:11:5 + | + 9 | def decrement_global(): + 10 | global global_counter + 11 | global_counter -= 1 + | ^^^^^^^^^^^^^^ + 12 | return global_counter + | + + info[references]: Reference 7 + --> main.py:12:12 + | + 10 | global global_counter + 11 | global_counter -= 1 + 12 | return global_counter + | ^^^^^^^^^^^^^^ + 13 | + 14 | # Use global_counter at module level + | + + info[references]: Reference 8 + --> main.py:15:17 + | + 14 | # Use global_counter at module level + 15 | initial_value = global_counter + | ^^^^^^^^^^^^^^ + 16 | increment_global() + 17 | decrement_global() + | + + info[references]: Reference 9 + --> main.py:18:15 + | + 16 | increment_global() + 17 | decrement_global() + 18 | final_value = global_counter + | ^^^^^^^^^^^^^^ + | + "); + } + + #[test] + fn test_except_handler_variable_references() { + let test = cursor_test( + " +try: + x = 1 / 0 +except ZeroDivisionError as err: + print(f'Error: {err}') + return err + +try: + y = 2 / 0 +except ValueError as err: + print(f'Different error: {err}') +", + ); + + // Note: Currently only finds the declaration, not the usages + // This is because semantic analysis for except handler variables isn't fully implemented + assert_snapshot!(test.references(), @r###" + info[references]: Reference 1 + --> main.py:4:29 + | + 2 | try: + 3 | x = 1 / 0 + 4 | except ZeroDivisionError as err: + | ^^^ + 5 | print(f'Error: {err}') + 6 | return err + | + "###); + } + + #[test] + fn test_pattern_match_as_references() { + let test = cursor_test( + " +match x: + case [a, b] as pattern: + print(f'Matched: {pattern}') + return pattern + case _: + pass +", + ); + + assert_snapshot!(test.references(), @r###" + info[references]: Reference 1 + --> main.py:3:20 + | + 2 | match x: + 3 | case [a, b] as pattern: + | ^^^^^^^ + 4 | print(f'Matched: {pattern}') + 5 | return pattern + | + + info[references]: Reference 2 + --> main.py:4:27 + | + 2 | match x: + 3 | case [a, b] as pattern: + 4 | print(f'Matched: {pattern}') + | ^^^^^^^ + 5 | return pattern + 6 | case _: + | + + info[references]: Reference 3 + --> main.py:5:16 + | + 3 | case [a, b] as pattern: + 4 | print(f'Matched: {pattern}') + 5 | return pattern + | ^^^^^^^ + 6 | case _: + 7 | pass + | + "###); + } + + #[test] + fn test_pattern_match_mapping_rest_references() { + let test = cursor_test( + " +match data: + case {'a': a, 'b': b, **rest}: + print(f'Rest data: {rest}') + process(rest) + return rest +", + ); + + assert_snapshot!(test.references(), @r###" + info[references]: Reference 1 + --> main.py:3:29 + | + 2 | match data: + 3 | case {'a': a, 'b': b, **rest}: + | ^^^^ + 4 | print(f'Rest data: {rest}') + 5 | process(rest) + | + + info[references]: Reference 2 + --> main.py:4:29 + | + 2 | match data: + 3 | case {'a': a, 'b': b, **rest}: + 4 | print(f'Rest data: {rest}') + | ^^^^ + 5 | process(rest) + 6 | return rest + | + + info[references]: Reference 3 + --> main.py:5:17 + | + 3 | case {'a': a, 'b': b, **rest}: + 4 | print(f'Rest data: {rest}') + 5 | process(rest) + | ^^^^ + 6 | return rest + | + + info[references]: Reference 4 + --> main.py:6:16 + | + 4 | print(f'Rest data: {rest}') + 5 | process(rest) + 6 | return rest + | ^^^^ + | + "###); + } + + #[test] + fn test_function_definition_references() { + let test = cursor_test( + " +def my_function(): + return 42 + +# Call the function multiple times +result1 = my_function() +result2 = my_function() + +# Function passed as an argument +callback = my_function + +# Function used in different contexts +print(my_function()) +value = my_function +", + ); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:2:5 + | + 2 | def my_function(): + | ^^^^^^^^^^^ + 3 | return 42 + | + + info[references]: Reference 2 + --> main.py:6:11 + | + 5 | # Call the function multiple times + 6 | result1 = my_function() + | ^^^^^^^^^^^ + 7 | result2 = my_function() + | + + info[references]: Reference 3 + --> main.py:7:11 + | + 5 | # Call the function multiple times + 6 | result1 = my_function() + 7 | result2 = my_function() + | ^^^^^^^^^^^ + 8 | + 9 | # Function passed as an argument + | + + info[references]: Reference 4 + --> main.py:10:12 + | + 9 | # Function passed as an argument + 10 | callback = my_function + | ^^^^^^^^^^^ + 11 | + 12 | # Function used in different contexts + | + + info[references]: Reference 5 + --> main.py:13:7 + | + 12 | # Function used in different contexts + 13 | print(my_function()) + | ^^^^^^^^^^^ + 14 | value = my_function + | + + info[references]: Reference 6 + --> main.py:14:9 + | + 12 | # Function used in different contexts + 13 | print(my_function()) + 14 | value = my_function + | ^^^^^^^^^^^ + | + "); + } + + #[test] + fn test_class_definition_references() { + let test = cursor_test( + " +class MyClass: + def __init__(self): + pass + +# Create instances +obj1 = MyClass() +obj2 = MyClass() + +# Use in type annotations +def process(instance: MyClass) -> MyClass: + return instance + +# Reference the class itself +cls = MyClass +", + ); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:2:7 + | + 2 | class MyClass: + | ^^^^^^^ + 3 | def __init__(self): + 4 | pass + | + + info[references]: Reference 2 + --> main.py:7:8 + | + 6 | # Create instances + 7 | obj1 = MyClass() + | ^^^^^^^ + 8 | obj2 = MyClass() + | + + info[references]: Reference 3 + --> main.py:8:8 + | + 6 | # Create instances + 7 | obj1 = MyClass() + 8 | obj2 = MyClass() + | ^^^^^^^ + 9 | + 10 | # Use in type annotations + | + + info[references]: Reference 4 + --> main.py:11:23 + | + 10 | # Use in type annotations + 11 | def process(instance: MyClass) -> MyClass: + | ^^^^^^^ + 12 | return instance + | + + info[references]: Reference 5 + --> main.py:11:35 + | + 10 | # Use in type annotations + 11 | def process(instance: MyClass) -> MyClass: + | ^^^^^^^ + 12 | return instance + | + + info[references]: Reference 6 + --> main.py:15:7 + | + 14 | # Reference the class itself + 15 | cls = MyClass + | ^^^^^^^ + | + "); + } + + #[test] + fn test_multi_file_function_references() { + let test = CursorTest::builder() + .source( + "utils.py", + " +def helper_function(x): + return x * 2 +", + ) + .source( + "module.py", + " +from utils import helper_function + +def process_data(data): + return helper_function(data) + +def double_process(data): + result = helper_function(data) + return helper_function(result) +", + ) + .source( + "app.py", + " +from utils import helper_function + +class DataProcessor: + def __init__(self): + self.multiplier = helper_function + + def process(self, value): + return helper_function(value) +", + ) + .build(); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> utils.py:2:5 + | + 2 | def helper_function(x): + | ^^^^^^^^^^^^^^^ + 3 | return x * 2 + | + + info[references]: Reference 2 + --> module.py:5:12 + | + 4 | def process_data(data): + 5 | return helper_function(data) + | ^^^^^^^^^^^^^^^ + 6 | + 7 | def double_process(data): + | + + info[references]: Reference 3 + --> module.py:8:14 + | + 7 | def double_process(data): + 8 | result = helper_function(data) + | ^^^^^^^^^^^^^^^ + 9 | return helper_function(result) + | + + info[references]: Reference 4 + --> module.py:9:12 + | + 7 | def double_process(data): + 8 | result = helper_function(data) + 9 | return helper_function(result) + | ^^^^^^^^^^^^^^^ + | + + info[references]: Reference 5 + --> app.py:6:27 + | + 4 | class DataProcessor: + 5 | def __init__(self): + 6 | self.multiplier = helper_function + | ^^^^^^^^^^^^^^^ + 7 | + 8 | def process(self, value): + | + + info[references]: Reference 6 + --> app.py:9:16 + | + 8 | def process(self, value): + 9 | return helper_function(value) + | ^^^^^^^^^^^^^^^ + | + "); + } + + #[test] + fn test_multi_file_class_attribute_references() { + let test = CursorTest::builder() + .source( + "models.py", + " +class MyModel: + attr = 42 + + def get_attribute(self): + return MyModel.attr +", + ) + .source( + "main.py", + " +from models import MyModel + +def process_model(): + model = MyModel() + value = model.attr + model.attr = 100 + return model.attr +", + ) + .build(); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> models.py:3:5 + | + 2 | class MyModel: + 3 | attr = 42 + | ^^^^ + 4 | + 5 | def get_attribute(self): + | + + info[references]: Reference 2 + --> models.py:6:24 + | + 5 | def get_attribute(self): + 6 | return MyModel.attr + | ^^^^ + | + + info[references]: Reference 3 + --> main.py:6:19 + | + 4 | def process_model(): + 5 | model = MyModel() + 6 | value = model.attr + | ^^^^ + 7 | model.attr = 100 + 8 | return model.attr + | + + info[references]: Reference 4 + --> main.py:7:11 + | + 5 | model = MyModel() + 6 | value = model.attr + 7 | model.attr = 100 + | ^^^^ + 8 | return model.attr + | + + info[references]: Reference 5 + --> main.py:8:18 + | + 6 | value = model.attr + 7 | model.attr = 100 + 8 | return model.attr + | ^^^^ + | + "); + } +} diff --git a/crates/ty_ide/src/lib.rs b/crates/ty_ide/src/lib.rs index b171a3acec..32a86c8caa 100644 --- a/crates/ty_ide/src/lib.rs +++ b/crates/ty_ide/src/lib.rs @@ -1,9 +1,11 @@ mod completion; +mod doc_highlights; mod docstring; mod find_node; mod goto; mod goto_declaration; mod goto_definition; +mod goto_references; mod goto_type_definition; mod hover; mod inlay_hints; @@ -14,12 +16,14 @@ mod signature_help; mod stub_mapping; pub use completion::completion; +pub use doc_highlights::document_highlights; pub use docstring::get_parameter_documentation; pub use goto::{goto_declaration, goto_definition, goto_type_definition}; +pub use goto_references::goto_references; pub use hover::hover; pub use inlay_hints::inlay_hints; pub use markup::MarkupKind; -pub use references::references; +pub use references::ReferencesMode; pub use semantic_tokens::{ SemanticToken, SemanticTokenModifier, SemanticTokenType, SemanticTokens, semantic_tokens, }; @@ -110,6 +114,53 @@ impl NavigationTarget { } } +/// Specifies the kind of reference operation. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ReferenceKind { + /// A read reference to a symbol (e.g., using a variable's value) + Read, + /// A write reference to a symbol (e.g., assigning to a variable) + Write, + /// Neither a read or a write (e.g., a function or class declaration) + Other, +} + +/// Target of a reference with information about the kind of operation. +/// Unlike `NavigationTarget`, this type is specifically designed for references +/// and contains only a single range (not separate focus/full ranges) and +/// includes information about whether the reference is a read or write operation. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ReferenceTarget { + file_range: FileRange, + kind: ReferenceKind, +} + +impl ReferenceTarget { + /// Creates a new `ReferenceTarget`. + pub fn new(file: File, range: TextRange, kind: ReferenceKind) -> Self { + Self { + file_range: FileRange::new(file, range), + kind, + } + } + + pub fn file(&self) -> File { + self.file_range.file() + } + + pub fn range(&self) -> TextRange { + self.file_range.range() + } + + pub fn file_range(&self) -> FileRange { + self.file_range + } + + pub fn kind(&self) -> ReferenceKind { + self.kind + } +} + #[derive(Debug, Clone)] pub struct NavigationTargets(smallvec::SmallVec<[NavigationTarget; 1]>); diff --git a/crates/ty_ide/src/references.rs b/crates/ty_ide/src/references.rs index b1a713963f..c84bf94a72 100644 --- a/crates/ty_ide/src/references.rs +++ b/crates/ty_ide/src/references.rs @@ -1,7 +1,8 @@ -//! This module implements the core functionality of the "references" and -//! "rename" language server features. It locates all references to a named -//! symbol. Unlike a simple text search for the symbol's name, this is -//! a "semantic search" where the text and the semantic meaning must match. +//! This module implements the core functionality of the "references", +//! "document highlight" and "rename" language server features. It locates +//! all references to a named symbol. Unlike a simple text search for the +//! symbol's name, this is a "semantic search" where the text and the semantic +//! meaning must match. //! //! Some symbols (such as parameters and local variables) are visible only //! within their scope. All other symbols, such as those defined at the global @@ -10,29 +11,39 @@ //! an expensive search of all source files in the workspace. use crate::find_node::CoveringNode; -use crate::goto::{GotoTarget, find_goto_target}; -use crate::{Db, NavigationTarget, NavigationTargets, RangedValue}; -use ruff_db::files::{File, FileRange}; +use crate::goto::GotoTarget; +use crate::{Db, NavigationTarget, ReferenceKind, ReferenceTarget}; +use ruff_db::files::File; use ruff_python_ast::{ self as ast, AnyNodeRef, visitor::source_order::{SourceOrderVisitor, TraversalSignal}, }; -use ruff_text_size::{Ranged, TextSize}; +use ruff_text_size::{Ranged, TextRange}; + +/// Mode for references search behavior +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReferencesMode { + /// Find all references including the declaration + References, + /// Find all references but skip the declaration + ReferencesSkipDeclaration, + /// Find references for rename operations (behavior differs for imported symbols) + Rename, + /// Find references for document highlights (limits search to current file) + DocumentHighlights, +} /// Find all references to a symbol at the given position. /// Search for references across all files in the project. -pub fn references( +pub(crate) fn references( db: &dyn Db, file: File, - offset: TextSize, - include_declaration: bool, -) -> Option>> { - let parsed = ruff_db::parsed::parsed_module(db, file); - let module = parsed.load(db); - + goto_target: &GotoTarget, + mode: ReferencesMode, +) -> Option> { // Get the definitions for the symbol at the cursor position - let goto_target = find_goto_target(&module, offset)?; - let target_definitions = goto_target.get_definition_targets(file, db, None)?; + let target_definitions_nav = goto_target.get_definition_targets(file, db, None)?; + let target_definitions: Vec = target_definitions_nav.into_iter().collect(); // Extract the target text from the goto target for fast comparison let target_text = goto_target.to_string()?; @@ -43,13 +54,16 @@ pub fn references( db, file, &target_definitions, - include_declaration, &target_text, + mode, &mut references, ); + // Check if we should search across files based on the mode + let search_across_files = !matches!(mode, ReferencesMode::DocumentHighlights); + // Check if the symbol is potentially visible outside of this module - if is_symbol_externally_visible(&goto_target) { + if search_across_files && is_symbol_externally_visible(goto_target) { // Look for references in all other files within the workspace for other_file in &db.project().files(db) { // Skip the current file as we already processed it @@ -68,8 +82,8 @@ pub fn references( db, other_file, &target_definitions, - false, // Don't include declarations from other files &target_text, + mode, &mut references, ); } @@ -82,16 +96,15 @@ pub fn references( } } -/// Find all references to a local symbol within the current file. If -/// `include_declaration` is true, return the original declaration for symbols -/// such as functions or classes that have a single declaration location. +/// Find all references to a local symbol within the current file. +/// The behavior depends on the provided mode. fn references_for_file( db: &dyn Db, file: File, - target_definitions: &NavigationTargets, - include_declaration: bool, + target_definitions: &[NavigationTarget], target_text: &str, - references: &mut Vec>, + mode: ReferencesMode, + references: &mut Vec, ) { let parsed = ruff_db::parsed::parsed_module(db, file); let module = parsed.load(db); @@ -101,7 +114,7 @@ fn references_for_file( file, target_definitions, references, - include_declaration, + mode, target_text, ancestors: Vec::new(), }; @@ -131,9 +144,9 @@ fn is_symbol_externally_visible(goto_target: &GotoTarget<'_>) -> bool { struct LocalReferencesFinder<'a> { db: &'a dyn Db, file: File, - target_definitions: &'a NavigationTargets, - references: &'a mut Vec>, - include_declaration: bool, + target_definitions: &'a [NavigationTarget], + references: &'a mut Vec, + mode: ReferencesMode, target_text: &'a str, ancestors: Vec>, } @@ -155,13 +168,13 @@ impl<'a> SourceOrderVisitor<'a> for LocalReferencesFinder<'a> { AnyNodeRef::ExprAttribute(attr_expr) => { self.check_identifier_reference(&attr_expr.attr); } - AnyNodeRef::StmtFunctionDef(func) if self.include_declaration => { + AnyNodeRef::StmtFunctionDef(func) if self.should_include_declaration() => { self.check_identifier_reference(&func.name); } - AnyNodeRef::StmtClassDef(class) if self.include_declaration => { + AnyNodeRef::StmtClassDef(class) if self.should_include_declaration() => { self.check_identifier_reference(&class.name); } - AnyNodeRef::Parameter(parameter) if self.include_declaration => { + AnyNodeRef::Parameter(parameter) if self.should_include_declaration() => { self.check_identifier_reference(¶meter.name); } AnyNodeRef::Keyword(keyword) => { @@ -169,27 +182,31 @@ impl<'a> SourceOrderVisitor<'a> for LocalReferencesFinder<'a> { self.check_identifier_reference(arg); } } - AnyNodeRef::StmtGlobal(global_stmt) if self.include_declaration => { + AnyNodeRef::StmtGlobal(global_stmt) if self.should_include_declaration() => { for name in &global_stmt.names { self.check_identifier_reference(name); } } - AnyNodeRef::StmtNonlocal(nonlocal_stmt) if self.include_declaration => { + AnyNodeRef::StmtNonlocal(nonlocal_stmt) if self.should_include_declaration() => { for name in &nonlocal_stmt.names { self.check_identifier_reference(name); } } - AnyNodeRef::ExceptHandlerExceptHandler(handler) if self.include_declaration => { + AnyNodeRef::ExceptHandlerExceptHandler(handler) + if self.should_include_declaration() => + { if let Some(name) = &handler.name { self.check_identifier_reference(name); } } - AnyNodeRef::PatternMatchAs(pattern_as) if self.include_declaration => { + AnyNodeRef::PatternMatchAs(pattern_as) if self.should_include_declaration() => { if let Some(name) = &pattern_as.name { self.check_identifier_reference(name); } } - AnyNodeRef::PatternMatchMapping(pattern_mapping) if self.include_declaration => { + AnyNodeRef::PatternMatchMapping(pattern_mapping) + if self.should_include_declaration() => + { if let Some(rest_name) = &pattern_mapping.rest { self.check_identifier_reference(rest_name); } @@ -207,6 +224,16 @@ impl<'a> SourceOrderVisitor<'a> for LocalReferencesFinder<'a> { } impl LocalReferencesFinder<'_> { + /// Check if we should include declarations based on the current mode + fn should_include_declaration(&self) -> bool { + matches!( + self.mode, + ReferencesMode::References + | ReferencesMode::DocumentHighlights + | ReferencesMode::Rename + ) + } + /// Helper method to check identifier references for declarations fn check_identifier_reference(&mut self, identifier: &ast::Identifier) { // Quick text-based check first @@ -237,23 +264,24 @@ impl LocalReferencesFinder<'_> { let range = covering_node.node().range(); // Get the definitions for this goto target - if let Some(current_definitions) = + if let Some(current_definitions_nav) = goto_target.get_definition_targets(self.file, self.db, None) { + let current_definitions: Vec = + current_definitions_nav.into_iter().collect(); // Check if any of the current definitions match our target definitions if self.navigation_targets_match(¤t_definitions) { - let target = NavigationTarget::new(self.file, range); - self.references.push(RangedValue { - value: NavigationTargets::single(target), - range: FileRange::new(self.file, range), - }); + // Determine if this is a read or write reference + let kind = self.determine_reference_kind(covering_node); + let target = ReferenceTarget::new(self.file, range, kind); + self.references.push(target); } } } } - /// Check if `NavigationTargets` match our target definitions - fn navigation_targets_match(&self, current_targets: &NavigationTargets) -> bool { + /// Check if `Vec` match our target definitions + fn navigation_targets_match(&self, current_targets: &[NavigationTarget]) -> bool { // Since we're comparing the same symbol, all definitions should be equivalent // We only need to check against the first target definition if let Some(first_target) = self.target_definitions.iter().next() { @@ -267,831 +295,105 @@ impl LocalReferencesFinder<'_> { } false } -} -#[cfg(test)] -mod tests { - use super::*; - use crate::tests::{CursorTest, IntoDiagnostic, cursor_test}; - use insta::assert_snapshot; - use ruff_db::diagnostic::{Annotation, Diagnostic, DiagnosticId, LintName, Severity, Span}; - use ruff_text_size::Ranged; + /// Determine whether a reference is a read or write operation based on its context + fn determine_reference_kind(&self, covering_node: &CoveringNode<'_>) -> ReferenceKind { + // Reference kind is only meaningful for DocumentHighlights mode + if !matches!(self.mode, ReferencesMode::DocumentHighlights) { + return ReferenceKind::Other; + } - impl CursorTest { - fn references(&self) -> String { - let Some(reference_results) = - references(&self.db, self.cursor.file, self.cursor.offset, true) - else { - return "No references found".to_string(); - }; - - if reference_results.is_empty() { - return "No references found".to_string(); - } - - self.render_diagnostics(reference_results.into_iter().enumerate().map( - |(i, ref_item)| -> ReferenceResult { - ReferenceResult { - index: i, - file_range: ref_item.range, + // Walk up the ancestors to find the context + for ancestor in self.ancestors.iter().rev() { + match ancestor { + // Assignment targets are writes + AnyNodeRef::StmtAssign(assign) => { + // Check if our node is in the targets (left side) of assignment + for target in &assign.targets { + if Self::expr_contains_range(target, covering_node.node().range()) { + return ReferenceKind::Write; + } } - }, - )) + } + AnyNodeRef::StmtAnnAssign(ann_assign) => { + // Check if our node is the target (left side) of annotated assignment + if Self::expr_contains_range(&ann_assign.target, covering_node.node().range()) { + return ReferenceKind::Write; + } + } + AnyNodeRef::StmtAugAssign(aug_assign) => { + // Check if our node is the target (left side) of augmented assignment + if Self::expr_contains_range(&aug_assign.target, covering_node.node().range()) { + return ReferenceKind::Write; + } + } + // For loop targets are writes + AnyNodeRef::StmtFor(for_stmt) => { + if Self::expr_contains_range(&for_stmt.target, covering_node.node().range()) { + return ReferenceKind::Write; + } + } + // With statement targets are writes + AnyNodeRef::WithItem(with_item) => { + if let Some(optional_vars) = &with_item.optional_vars { + if Self::expr_contains_range(optional_vars, covering_node.node().range()) { + return ReferenceKind::Write; + } + } + } + // Exception handler names are writes + AnyNodeRef::ExceptHandlerExceptHandler(handler) => { + if let Some(name) = &handler.name { + if Self::node_contains_range( + AnyNodeRef::from(name), + covering_node.node().range(), + ) { + return ReferenceKind::Write; + } + } + } + AnyNodeRef::StmtFunctionDef(func) => { + if Self::node_contains_range( + AnyNodeRef::from(&func.name), + covering_node.node().range(), + ) { + return ReferenceKind::Other; + } + } + AnyNodeRef::StmtClassDef(class) => { + if Self::node_contains_range( + AnyNodeRef::from(&class.name), + covering_node.node().range(), + ) { + return ReferenceKind::Other; + } + } + AnyNodeRef::Parameter(param) => { + if Self::node_contains_range( + AnyNodeRef::from(¶m.name), + covering_node.node().range(), + ) { + return ReferenceKind::Other; + } + } + AnyNodeRef::StmtGlobal(_) | AnyNodeRef::StmtNonlocal(_) => { + return ReferenceKind::Other; + } + _ => {} + } } + + // Default to read + ReferenceKind::Read } - struct ReferenceResult { - index: usize, - file_range: FileRange, + /// Helper to check if a node contains a given range + fn node_contains_range(node: AnyNodeRef<'_>, range: TextRange) -> bool { + node.range().contains_range(range) } - impl IntoDiagnostic for ReferenceResult { - fn into_diagnostic(self) -> Diagnostic { - let mut main = Diagnostic::new( - DiagnosticId::Lint(LintName::of("references")), - Severity::Info, - format!("Reference {}", self.index + 1), - ); - main.annotate(Annotation::primary( - Span::from(self.file_range.file()).with_range(self.file_range.range()), - )); - - main - } - } - - #[test] - fn test_parameter_references_in_function() { - let test = cursor_test( - " -def calculate_sum(value: int) -> int: - doubled = value * 2 - result = value + doubled - return value - -# Call with keyword argument -result = calculate_sum(value=42) -", - ); - - assert_snapshot!(test.references(), @r###" - info[references]: Reference 1 - --> main.py:2:19 - | - 2 | def calculate_sum(value: int) -> int: - | ^^^^^ - 3 | doubled = value * 2 - 4 | result = value + doubled - | - - info[references]: Reference 2 - --> main.py:3:15 - | - 2 | def calculate_sum(value: int) -> int: - 3 | doubled = value * 2 - | ^^^^^ - 4 | result = value + doubled - 5 | return value - | - - info[references]: Reference 3 - --> main.py:4:14 - | - 2 | def calculate_sum(value: int) -> int: - 3 | doubled = value * 2 - 4 | result = value + doubled - | ^^^^^ - 5 | return value - | - - info[references]: Reference 4 - --> main.py:5:12 - | - 3 | doubled = value * 2 - 4 | result = value + doubled - 5 | return value - | ^^^^^ - 6 | - 7 | # Call with keyword argument - | - - info[references]: Reference 5 - --> main.py:8:24 - | - 7 | # Call with keyword argument - 8 | result = calculate_sum(value=42) - | ^^^^^ - | - "###); - } - - #[test] - #[ignore] // TODO: Enable when nonlocal support is fully implemented in goto.rs - fn test_nonlocal_variable_references() { - let test = cursor_test( - " -def outer_function(): - counter = 0 - - def increment(): - nonlocal counter - counter += 1 - return counter - - def decrement(): - nonlocal counter - counter -= 1 - return counter - - # Use counter in outer scope - initial = counter - increment() - decrement() - final = counter - - return increment, decrement -", - ); - - assert_snapshot!(test.references(), @r" - info[references]: Reference 1 - --> main.py:3:5 - | - 2 | def outer_function(): - 3 | counter = 0 - | ^^^^^^^ - 4 | - 5 | def increment(): - | - - info[references]: Reference 2 - --> main.py:6:18 - | - 5 | def increment(): - 6 | nonlocal counter - | ^^^^^^^ - 7 | counter += 1 - 8 | return counter - | - - info[references]: Reference 3 - --> main.py:7:9 - | - 5 | def increment(): - 6 | nonlocal counter - 7 | counter += 1 - | ^^^^^^^ - 8 | return counter - | - - info[references]: Reference 4 - --> main.py:8:16 - | - 6 | nonlocal counter - 7 | counter += 1 - 8 | return counter - | ^^^^^^^ - 9 | - 10 | def decrement(): - | - - info[references]: Reference 5 - --> main.py:11:18 - | - 10 | def decrement(): - 11 | nonlocal counter - | ^^^^^^^ - 12 | counter -= 1 - 13 | return counter - | - - info[references]: Reference 6 - --> main.py:12:9 - | - 10 | def decrement(): - 11 | nonlocal counter - 12 | counter -= 1 - | ^^^^^^^ - 13 | return counter - | - - info[references]: Reference 7 - --> main.py:13:16 - | - 11 | nonlocal counter - 12 | counter -= 1 - 13 | return counter - | ^^^^^^^ - 14 | - 15 | # Use counter in outer scope - | - - info[references]: Reference 8 - --> main.py:16:15 - | - 15 | # Use counter in outer scope - 16 | initial = counter - | ^^^^^^^ - 17 | increment() - 18 | decrement() - | - - info[references]: Reference 9 - --> main.py:19:13 - | - 17 | increment() - 18 | decrement() - 19 | final = counter - | ^^^^^^^ - 20 | - 21 | return increment, decrement - | - "); - } - - #[test] - #[ignore] // TODO: Enable when global support is fully implemented in goto.rs - fn test_global_variable_references() { - let test = cursor_test( - " -global_counter = 0 - -def increment_global(): - global global_counter - global_counter += 1 - return global_counter - -def decrement_global(): - global global_counter - global_counter -= 1 - return global_counter - -# Use global_counter at module level -initial_value = global_counter -increment_global() -decrement_global() -final_value = global_counter -", - ); - - assert_snapshot!(test.references(), @r" - info[references]: Reference 1 - --> main.py:2:1 - | - 2 | global_counter = 0 - | ^^^^^^^^^^^^^^ - 3 | - 4 | def increment_global(): - | - - info[references]: Reference 2 - --> main.py:5:12 - | - 4 | def increment_global(): - 5 | global global_counter - | ^^^^^^^^^^^^^^ - 6 | global_counter += 1 - 7 | return global_counter - | - - info[references]: Reference 3 - --> main.py:6:5 - | - 4 | def increment_global(): - 5 | global global_counter - 6 | global_counter += 1 - | ^^^^^^^^^^^^^^ - 7 | return global_counter - | - - info[references]: Reference 4 - --> main.py:7:12 - | - 5 | global global_counter - 6 | global_counter += 1 - 7 | return global_counter - | ^^^^^^^^^^^^^^ - 8 | - 9 | def decrement_global(): - | - - info[references]: Reference 5 - --> main.py:10:12 - | - 9 | def decrement_global(): - 10 | global global_counter - | ^^^^^^^^^^^^^^ - 11 | global_counter -= 1 - 12 | return global_counter - | - - info[references]: Reference 6 - --> main.py:11:5 - | - 9 | def decrement_global(): - 10 | global global_counter - 11 | global_counter -= 1 - | ^^^^^^^^^^^^^^ - 12 | return global_counter - | - - info[references]: Reference 7 - --> main.py:12:12 - | - 10 | global global_counter - 11 | global_counter -= 1 - 12 | return global_counter - | ^^^^^^^^^^^^^^ - 13 | - 14 | # Use global_counter at module level - | - - info[references]: Reference 8 - --> main.py:15:17 - | - 14 | # Use global_counter at module level - 15 | initial_value = global_counter - | ^^^^^^^^^^^^^^ - 16 | increment_global() - 17 | decrement_global() - | - - info[references]: Reference 9 - --> main.py:18:15 - | - 16 | increment_global() - 17 | decrement_global() - 18 | final_value = global_counter - | ^^^^^^^^^^^^^^ - | - "); - } - - #[test] - fn test_except_handler_variable_references() { - let test = cursor_test( - " -try: - x = 1 / 0 -except ZeroDivisionError as err: - print(f'Error: {err}') - return err - -try: - y = 2 / 0 -except ValueError as err: - print(f'Different error: {err}') -", - ); - - // Note: Currently only finds the declaration, not the usages - // This is because semantic analysis for except handler variables isn't fully implemented - assert_snapshot!(test.references(), @r###" - info[references]: Reference 1 - --> main.py:4:29 - | - 2 | try: - 3 | x = 1 / 0 - 4 | except ZeroDivisionError as err: - | ^^^ - 5 | print(f'Error: {err}') - 6 | return err - | - "###); - } - - #[test] - fn test_pattern_match_as_references() { - let test = cursor_test( - " -match x: - case [a, b] as pattern: - print(f'Matched: {pattern}') - return pattern - case _: - pass -", - ); - - assert_snapshot!(test.references(), @r###" - info[references]: Reference 1 - --> main.py:3:20 - | - 2 | match x: - 3 | case [a, b] as pattern: - | ^^^^^^^ - 4 | print(f'Matched: {pattern}') - 5 | return pattern - | - - info[references]: Reference 2 - --> main.py:4:27 - | - 2 | match x: - 3 | case [a, b] as pattern: - 4 | print(f'Matched: {pattern}') - | ^^^^^^^ - 5 | return pattern - 6 | case _: - | - - info[references]: Reference 3 - --> main.py:5:16 - | - 3 | case [a, b] as pattern: - 4 | print(f'Matched: {pattern}') - 5 | return pattern - | ^^^^^^^ - 6 | case _: - 7 | pass - | - "###); - } - - #[test] - fn test_pattern_match_mapping_rest_references() { - let test = cursor_test( - " -match data: - case {'a': a, 'b': b, **rest}: - print(f'Rest data: {rest}') - process(rest) - return rest -", - ); - - assert_snapshot!(test.references(), @r###" - info[references]: Reference 1 - --> main.py:3:29 - | - 2 | match data: - 3 | case {'a': a, 'b': b, **rest}: - | ^^^^ - 4 | print(f'Rest data: {rest}') - 5 | process(rest) - | - - info[references]: Reference 2 - --> main.py:4:29 - | - 2 | match data: - 3 | case {'a': a, 'b': b, **rest}: - 4 | print(f'Rest data: {rest}') - | ^^^^ - 5 | process(rest) - 6 | return rest - | - - info[references]: Reference 3 - --> main.py:5:17 - | - 3 | case {'a': a, 'b': b, **rest}: - 4 | print(f'Rest data: {rest}') - 5 | process(rest) - | ^^^^ - 6 | return rest - | - - info[references]: Reference 4 - --> main.py:6:16 - | - 4 | print(f'Rest data: {rest}') - 5 | process(rest) - 6 | return rest - | ^^^^ - | - "###); - } - - #[test] - fn test_function_definition_references() { - let test = cursor_test( - " -def my_function(): - return 42 - -# Call the function multiple times -result1 = my_function() -result2 = my_function() - -# Function passed as an argument -callback = my_function - -# Function used in different contexts -print(my_function()) -value = my_function -", - ); - - assert_snapshot!(test.references(), @r" - info[references]: Reference 1 - --> main.py:2:5 - | - 2 | def my_function(): - | ^^^^^^^^^^^ - 3 | return 42 - | - - info[references]: Reference 2 - --> main.py:6:11 - | - 5 | # Call the function multiple times - 6 | result1 = my_function() - | ^^^^^^^^^^^ - 7 | result2 = my_function() - | - - info[references]: Reference 3 - --> main.py:7:11 - | - 5 | # Call the function multiple times - 6 | result1 = my_function() - 7 | result2 = my_function() - | ^^^^^^^^^^^ - 8 | - 9 | # Function passed as an argument - | - - info[references]: Reference 4 - --> main.py:10:12 - | - 9 | # Function passed as an argument - 10 | callback = my_function - | ^^^^^^^^^^^ - 11 | - 12 | # Function used in different contexts - | - - info[references]: Reference 5 - --> main.py:13:7 - | - 12 | # Function used in different contexts - 13 | print(my_function()) - | ^^^^^^^^^^^ - 14 | value = my_function - | - - info[references]: Reference 6 - --> main.py:14:9 - | - 12 | # Function used in different contexts - 13 | print(my_function()) - 14 | value = my_function - | ^^^^^^^^^^^ - | - "); - } - - #[test] - fn test_class_definition_references() { - let test = cursor_test( - " -class MyClass: - def __init__(self): - pass - -# Create instances -obj1 = MyClass() -obj2 = MyClass() - -# Use in type annotations -def process(instance: MyClass) -> MyClass: - return instance - -# Reference the class itself -cls = MyClass -", - ); - - assert_snapshot!(test.references(), @r" - info[references]: Reference 1 - --> main.py:2:7 - | - 2 | class MyClass: - | ^^^^^^^ - 3 | def __init__(self): - 4 | pass - | - - info[references]: Reference 2 - --> main.py:7:8 - | - 6 | # Create instances - 7 | obj1 = MyClass() - | ^^^^^^^ - 8 | obj2 = MyClass() - | - - info[references]: Reference 3 - --> main.py:8:8 - | - 6 | # Create instances - 7 | obj1 = MyClass() - 8 | obj2 = MyClass() - | ^^^^^^^ - 9 | - 10 | # Use in type annotations - | - - info[references]: Reference 4 - --> main.py:11:23 - | - 10 | # Use in type annotations - 11 | def process(instance: MyClass) -> MyClass: - | ^^^^^^^ - 12 | return instance - | - - info[references]: Reference 5 - --> main.py:11:35 - | - 10 | # Use in type annotations - 11 | def process(instance: MyClass) -> MyClass: - | ^^^^^^^ - 12 | return instance - | - - info[references]: Reference 6 - --> main.py:15:7 - | - 14 | # Reference the class itself - 15 | cls = MyClass - | ^^^^^^^ - | - "); - } - - #[test] - fn test_multi_file_function_references() { - let test = CursorTest::builder() - .source( - "utils.py", - " -def helper_function(x): - return x * 2 -", - ) - .source( - "module.py", - " -from utils import helper_function - -def process_data(data): - return helper_function(data) - -def double_process(data): - result = helper_function(data) - return helper_function(result) -", - ) - .source( - "app.py", - " -from utils import helper_function - -class DataProcessor: - def __init__(self): - self.multiplier = helper_function - - def process(self, value): - return helper_function(value) -", - ) - .build(); - - assert_snapshot!(test.references(), @r" - info[references]: Reference 1 - --> utils.py:2:5 - | - 2 | def helper_function(x): - | ^^^^^^^^^^^^^^^ - 3 | return x * 2 - | - - info[references]: Reference 2 - --> module.py:5:12 - | - 4 | def process_data(data): - 5 | return helper_function(data) - | ^^^^^^^^^^^^^^^ - 6 | - 7 | def double_process(data): - | - - info[references]: Reference 3 - --> module.py:8:14 - | - 7 | def double_process(data): - 8 | result = helper_function(data) - | ^^^^^^^^^^^^^^^ - 9 | return helper_function(result) - | - - info[references]: Reference 4 - --> module.py:9:12 - | - 7 | def double_process(data): - 8 | result = helper_function(data) - 9 | return helper_function(result) - | ^^^^^^^^^^^^^^^ - | - - info[references]: Reference 5 - --> app.py:6:27 - | - 4 | class DataProcessor: - 5 | def __init__(self): - 6 | self.multiplier = helper_function - | ^^^^^^^^^^^^^^^ - 7 | - 8 | def process(self, value): - | - - info[references]: Reference 6 - --> app.py:9:16 - | - 8 | def process(self, value): - 9 | return helper_function(value) - | ^^^^^^^^^^^^^^^ - | - "); - } - - #[test] - fn test_multi_file_class_attribute_references() { - let test = CursorTest::builder() - .source( - "models.py", - " -class MyModel: - attr = 42 - - def get_attribute(self): - return MyModel.attr -", - ) - .source( - "main.py", - " -from models import MyModel - -def process_model(): - model = MyModel() - value = model.attr - model.attr = 100 - return model.attr -", - ) - .build(); - - assert_snapshot!(test.references(), @r" - info[references]: Reference 1 - --> models.py:3:5 - | - 2 | class MyModel: - 3 | attr = 42 - | ^^^^ - 4 | - 5 | def get_attribute(self): - | - - info[references]: Reference 2 - --> models.py:6:24 - | - 5 | def get_attribute(self): - 6 | return MyModel.attr - | ^^^^ - | - - info[references]: Reference 3 - --> main.py:6:19 - | - 4 | def process_model(): - 5 | model = MyModel() - 6 | value = model.attr - | ^^^^ - 7 | model.attr = 100 - 8 | return model.attr - | - - info[references]: Reference 4 - --> main.py:7:11 - | - 5 | model = MyModel() - 6 | value = model.attr - 7 | model.attr = 100 - | ^^^^ - 8 | return model.attr - | - - info[references]: Reference 5 - --> main.py:8:18 - | - 6 | value = model.attr - 7 | model.attr = 100 - 8 | return model.attr - | ^^^^ - | - "); + /// Helper to check if an expression contains a given range + fn expr_contains_range(expr: &ast::Expr, range: TextRange) -> bool { + expr.range().contains_range(range) } } diff --git a/crates/ty_server/src/document/location.rs b/crates/ty_server/src/document/location.rs index 467f35d8cf..d5924595b2 100644 --- a/crates/ty_server/src/document/location.rs +++ b/crates/ty_server/src/document/location.rs @@ -5,7 +5,7 @@ use lsp_types::Location; use ruff_db::files::FileRange; use ruff_db::source::{line_index, source_text}; use ruff_text_size::Ranged; -use ty_ide::NavigationTarget; +use ty_ide::{NavigationTarget, ReferenceTarget}; use ty_project::Db; pub(crate) trait ToLink { @@ -53,3 +53,37 @@ impl ToLink for NavigationTarget { }) } } + +impl ToLink for ReferenceTarget { + fn to_location(&self, db: &dyn Db, encoding: PositionEncoding) -> Option { + self.file_range().to_location(db, encoding) + } + + fn to_link( + &self, + db: &dyn Db, + src: Option, + encoding: PositionEncoding, + ) -> Option { + let uri = file_to_url(db, self.file())?; + let source = source_text(db, self.file()); + let index = line_index(db, self.file()); + + let target_range = self.range().to_lsp_range(&source, &index, encoding); + let selection_range = target_range; + + let src = src.map(|src| { + let source = source_text(db, src.file()); + let index = line_index(db, src.file()); + + src.range().to_lsp_range(&source, &index, encoding) + }); + + Some(lsp_types::LocationLink { + target_uri: uri, + target_range, + target_selection_range: selection_range, + origin_selection_range: src, + }) + } +} diff --git a/crates/ty_server/src/server.rs b/crates/ty_server/src/server.rs index 29a34d8a23..2ae64d7c9e 100644 --- a/crates/ty_server/src/server.rs +++ b/crates/ty_server/src/server.rs @@ -210,6 +210,7 @@ impl Server { definition_provider: Some(lsp_types::OneOf::Left(true)), declaration_provider: Some(DeclarationCapability::Simple(true)), references_provider: Some(lsp_types::OneOf::Left(true)), + document_highlight_provider: Some(lsp_types::OneOf::Left(true)), hover_provider: Some(HoverProviderCapability::Simple(true)), signature_help_provider: Some(SignatureHelpOptions { trigger_characters: Some(vec!["(".to_string(), ",".to_string()]), diff --git a/crates/ty_server/src/server/api.rs b/crates/ty_server/src/server/api.rs index 062fc0e103..47b7c508c1 100644 --- a/crates/ty_server/src/server/api.rs +++ b/crates/ty_server/src/server/api.rs @@ -59,6 +59,11 @@ pub(super) fn request(req: server::Request) -> Task { requests::ReferencesRequestHandler::METHOD => background_document_request_task::< requests::ReferencesRequestHandler, >(req, BackgroundSchedule::Worker), + requests::DocumentHighlightRequestHandler::METHOD => background_document_request_task::< + requests::DocumentHighlightRequestHandler, + >( + req, BackgroundSchedule::Worker + ), requests::InlayHintRequestHandler::METHOD => background_document_request_task::< requests::InlayHintRequestHandler, >(req, BackgroundSchedule::Worker), diff --git a/crates/ty_server/src/server/api/requests.rs b/crates/ty_server/src/server/api/requests.rs index ff7772c35d..71d91dba0f 100644 --- a/crates/ty_server/src/server/api/requests.rs +++ b/crates/ty_server/src/server/api/requests.rs @@ -1,11 +1,12 @@ mod completion; mod diagnostic; +mod doc_highlights; mod goto_declaration; mod goto_definition; +mod goto_references; mod goto_type_definition; mod hover; mod inlay_hints; -mod references; mod semantic_tokens; mod semantic_tokens_range; mod shutdown; @@ -14,12 +15,13 @@ mod workspace_diagnostic; pub(super) use completion::CompletionRequestHandler; pub(super) use diagnostic::DocumentDiagnosticRequestHandler; +pub(super) use doc_highlights::DocumentHighlightRequestHandler; pub(super) use goto_declaration::GotoDeclarationRequestHandler; pub(super) use goto_definition::GotoDefinitionRequestHandler; +pub(super) use goto_references::ReferencesRequestHandler; pub(super) use goto_type_definition::GotoTypeDefinitionRequestHandler; pub(super) use hover::HoverRequestHandler; pub(super) use inlay_hints::InlayHintRequestHandler; -pub(super) use references::ReferencesRequestHandler; pub(super) use semantic_tokens::SemanticTokensRequestHandler; pub(super) use semantic_tokens_range::SemanticTokensRangeRequestHandler; pub(super) use shutdown::ShutdownHandler; diff --git a/crates/ty_server/src/server/api/requests/doc_highlights.rs b/crates/ty_server/src/server/api/requests/doc_highlights.rs new file mode 100644 index 0000000000..c9bb2a5e32 --- /dev/null +++ b/crates/ty_server/src/server/api/requests/doc_highlights.rs @@ -0,0 +1,74 @@ +use std::borrow::Cow; + +use lsp_types::request::DocumentHighlightRequest; +use lsp_types::{DocumentHighlight, DocumentHighlightKind, DocumentHighlightParams, Url}; +use ruff_db::source::{line_index, source_text}; +use ty_ide::{ReferenceKind, document_highlights}; +use ty_project::ProjectDatabase; + +use crate::document::{PositionExt, ToRangeExt}; +use crate::server::api::traits::{ + BackgroundDocumentRequestHandler, RequestHandler, RetriableRequestHandler, +}; +use crate::session::DocumentSnapshot; +use crate::session::client::Client; + +pub(crate) struct DocumentHighlightRequestHandler; + +impl RequestHandler for DocumentHighlightRequestHandler { + type RequestType = DocumentHighlightRequest; +} + +impl BackgroundDocumentRequestHandler for DocumentHighlightRequestHandler { + fn document_url(params: &DocumentHighlightParams) -> Cow { + Cow::Borrowed(¶ms.text_document_position_params.text_document.uri) + } + + fn run_with_snapshot( + db: &ProjectDatabase, + snapshot: DocumentSnapshot, + _client: &Client, + params: DocumentHighlightParams, + ) -> crate::server::Result>> { + if snapshot.client_settings().is_language_services_disabled() { + return Ok(None); + } + + let Some(file) = snapshot.file(db) else { + return Ok(None); + }; + + let source = source_text(db, file); + let line_index = line_index(db, file); + let offset = params.text_document_position_params.position.to_text_size( + &source, + &line_index, + snapshot.encoding(), + ); + + let Some(highlights_result) = document_highlights(db, file, offset) else { + return Ok(None); + }; + + let highlights: Vec<_> = highlights_result + .into_iter() + .map(|target| { + let range = target + .range() + .to_lsp_range(&source, &line_index, snapshot.encoding()); + + let kind = match target.kind() { + ReferenceKind::Read => Some(DocumentHighlightKind::READ), + ReferenceKind::Write => Some(DocumentHighlightKind::WRITE), + ReferenceKind::Other => Some(DocumentHighlightKind::TEXT), + }; + + DocumentHighlight { range, kind } + }) + .collect(); + + Ok(Some(highlights)) + } +} + +impl RetriableRequestHandler for DocumentHighlightRequestHandler {} diff --git a/crates/ty_server/src/server/api/requests/references.rs b/crates/ty_server/src/server/api/requests/goto_references.rs similarity index 91% rename from crates/ty_server/src/server/api/requests/references.rs rename to crates/ty_server/src/server/api/requests/goto_references.rs index 63c713ac3e..f630d7b815 100644 --- a/crates/ty_server/src/server/api/requests/references.rs +++ b/crates/ty_server/src/server/api/requests/goto_references.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use lsp_types::request::References; use lsp_types::{Location, ReferenceParams, Url}; use ruff_db::source::{line_index, source_text}; -use ty_ide::references; +use ty_ide::goto_references; use ty_project::ProjectDatabase; use crate::document::{PositionExt, ToLink}; @@ -48,13 +48,12 @@ impl BackgroundDocumentRequestHandler for ReferencesRequestHandler { let include_declaration = params.context.include_declaration; - let Some(references_result) = references(db, file, offset, include_declaration) else { + let Some(references_result) = goto_references(db, file, offset, include_declaration) else { return Ok(None); }; let locations: Vec<_> = references_result .into_iter() - .flat_map(|ranged| ranged.value.into_iter()) .filter_map(|target| target.to_location(db, snapshot.encoding())) .collect(); diff --git a/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization.snap b/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization.snap index 8dcfd73469..c137fa8693 100644 --- a/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization.snap +++ b/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization.snap @@ -27,6 +27,7 @@ expression: initialization_result "definitionProvider": true, "typeDefinitionProvider": true, "referencesProvider": true, + "documentHighlightProvider": true, "declarationProvider": true, "semanticTokensProvider": { "legend": { diff --git a/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization_with_workspace.snap b/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization_with_workspace.snap index 8dcfd73469..c137fa8693 100644 --- a/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization_with_workspace.snap +++ b/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization_with_workspace.snap @@ -27,6 +27,7 @@ expression: initialization_result "definitionProvider": true, "typeDefinitionProvider": true, "referencesProvider": true, + "documentHighlightProvider": true, "declarationProvider": true, "semanticTokensProvider": { "legend": { diff --git a/crates/ty_wasm/src/lib.rs b/crates/ty_wasm/src/lib.rs index 66f7e23bec..cfc7691f4c 100644 --- a/crates/ty_wasm/src/lib.rs +++ b/crates/ty_wasm/src/lib.rs @@ -15,10 +15,9 @@ use ruff_python_formatter::formatted_file; use ruff_source_file::{LineIndex, OneIndexed, SourceLocation}; use ruff_text_size::{Ranged, TextSize}; use ty_ide::{ - MarkupKind, RangedValue, goto_declaration, goto_definition, goto_type_definition, hover, - inlay_hints, references, + MarkupKind, NavigationTargets, RangedValue, goto_declaration, goto_definition, goto_references, + goto_type_definition, hover, inlay_hints, signature_help, }; -use ty_ide::{NavigationTargets, signature_help}; use ty_project::metadata::options::Options; use ty_project::metadata::value::ValueSource; use ty_project::watch::{ChangeEvent, ChangedKind, CreatedKind, DeletedKind}; @@ -338,14 +337,30 @@ impl Workspace { let offset = position.to_text_size(&source, &index, self.position_encoding)?; - let Some(targets) = references(&self.db, file_id.file, offset, true) else { + let Some(targets) = goto_references(&self.db, file_id.file, offset, true) else { return Ok(Vec::new()); }; Ok(targets .into_iter() - .flat_map(|target| { - map_targets_to_links(&self.db, target, &source, &index, self.position_encoding) + .map(|target| LocationLink { + path: target.file().path(&self.db).to_string(), + full_range: Range::from_file_range( + &self.db, + target.file_range(), + self.position_encoding, + ), + selection_range: Some(Range::from_file_range( + &self.db, + target.file_range(), + self.position_encoding, + )), + origin_selection_range: Some(Range::from_text_range( + ruff_text_size::TextRange::new(offset, offset), + &index, + &source, + self.position_encoding, + )), }) .collect()) }