diff --git a/Cargo.lock b/Cargo.lock index 64ef7eb620..806f4d1ac9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2507,10 +2507,15 @@ dependencies = [ name = "red_knot_ide" version = "0.0.0" dependencies = [ + "insta", "red_knot_python_semantic", "red_knot_vendored", "ruff_db", + "ruff_python_ast", + "ruff_python_parser", + "ruff_text_size", "salsa", + "smallvec", "tracing", ] @@ -2597,7 +2602,9 @@ dependencies = [ "libc", "lsp-server", "lsp-types", + "red_knot_ide", "red_knot_project", + "red_knot_python_semantic", "ruff_db", "ruff_notebook", "ruff_python_ast", diff --git a/crates/red_knot_ide/Cargo.toml b/crates/red_knot_ide/Cargo.toml index 184154a9b1..27117a6b2e 100644 --- a/crates/red_knot_ide/Cargo.toml +++ b/crates/red_knot_ide/Cargo.toml @@ -12,13 +12,19 @@ license = { workspace = true } [dependencies] ruff_db = { workspace = true } +ruff_python_ast = { workspace = true } +ruff_python_parser = { workspace = true } +ruff_text_size = { workspace = true } red_knot_python_semantic = { workspace = true } salsa = { workspace = true } +smallvec = { workspace = true } tracing = { workspace = true } [dev-dependencies] red_knot_vendored = { workspace = true } +insta = { workspace = true, features = ["filters"] } + [lints] workspace = true diff --git a/crates/red_knot_ide/src/db.rs b/crates/red_knot_ide/src/db.rs index fff5b7d7dc..ba49ee864f 100644 --- a/crates/red_knot_ide/src/db.rs +++ b/crates/red_knot_ide/src/db.rs @@ -2,7 +2,7 @@ use red_knot_python_semantic::Db as SemanticDb; use ruff_db::{Db as SourceDb, Upcast}; #[salsa::db] -pub trait Db: SemanticDb + Upcast {} +pub trait Db: SemanticDb + Upcast + Upcast {} #[cfg(test)] pub(crate) mod tests { @@ -94,6 +94,16 @@ pub(crate) mod tests { } } + impl Upcast for TestDb { + fn upcast(&self) -> &(dyn SemanticDb + 'static) { + self + } + + fn upcast_mut(&mut self) -> &mut dyn SemanticDb { + self + } + } + #[salsa::db] impl SemanticDb for TestDb { fn is_file_open(&self, file: File) -> bool { diff --git a/crates/red_knot_ide/src/find_node.rs b/crates/red_knot_ide/src/find_node.rs new file mode 100644 index 0000000000..c3ea78d3d6 --- /dev/null +++ b/crates/red_knot_ide/src/find_node.rs @@ -0,0 +1,106 @@ +use ruff_python_ast::visitor::source_order::{SourceOrderVisitor, TraversalSignal}; +use ruff_python_ast::AnyNodeRef; +use ruff_text_size::{Ranged, TextRange}; +use std::fmt; +use std::fmt::Formatter; + +/// Returns the node with a minimal range that fully contains `range`. +/// +/// If `range` is empty and falls within a parser *synthesized* node generated during error recovery, +/// then the first node with the given range is returned. +/// +/// ## Panics +/// Panics if `range` is not contained within `root`. +pub(crate) fn covering_node(root: AnyNodeRef, range: TextRange) -> CoveringNode { + struct Visitor<'a> { + range: TextRange, + found: bool, + ancestors: Vec>, + } + + impl<'a> SourceOrderVisitor<'a> for Visitor<'a> { + fn enter_node(&mut self, node: AnyNodeRef<'a>) -> TraversalSignal { + // If the node fully contains the range, than it is a possible match but traverse into its children + // to see if there's a node with a narrower range. + if !self.found && node.range().contains_range(self.range) { + self.ancestors.push(node); + TraversalSignal::Traverse + } else { + TraversalSignal::Skip + } + } + + fn leave_node(&mut self, node: AnyNodeRef<'a>) { + if !self.found && self.ancestors.last() == Some(&node) { + self.found = true; + } + } + } + + assert!( + root.range().contains_range(range), + "Range is not contained within root" + ); + + let mut visitor = Visitor { + range, + found: false, + ancestors: Vec::new(), + }; + + root.visit_source_order(&mut visitor); + + let minimal = visitor.ancestors.pop().unwrap_or(root); + CoveringNode { + node: minimal, + ancestors: visitor.ancestors, + } +} + +/// The node with a minimal range that fully contains the search range. +pub(crate) struct CoveringNode<'a> { + /// The node with a minimal range that fully contains the search range. + node: AnyNodeRef<'a>, + + /// The node's ancestor (the spine up to the root). + ancestors: Vec>, +} + +impl<'a> CoveringNode<'a> { + pub(crate) fn node(&self) -> AnyNodeRef<'a> { + self.node + } + + /// Returns the node's parent. + pub(crate) fn parent(&self) -> Option> { + self.ancestors.last().copied() + } + + /// Finds the minimal node that fully covers the range and fulfills the given predicate. + pub(crate) fn find(mut self, f: impl Fn(AnyNodeRef<'a>) -> bool) -> Result { + if f(self.node) { + return Ok(self); + } + + match self.ancestors.iter().rposition(|node| f(*node)) { + Some(index) => { + let node = self.ancestors[index]; + self.ancestors.truncate(index); + + Ok(Self { + node, + ancestors: self.ancestors, + }) + } + None => Err(self), + } + } +} + +impl fmt::Debug for CoveringNode<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_tuple("NodeWithAncestors") + .field(&self.node) + .finish() + } +} diff --git a/crates/red_knot_ide/src/goto.rs b/crates/red_knot_ide/src/goto.rs new file mode 100644 index 0000000000..36c49fb7f2 --- /dev/null +++ b/crates/red_knot_ide/src/goto.rs @@ -0,0 +1,914 @@ +use crate::find_node::covering_node; +use crate::{Db, HasNavigationTargets, NavigationTargets, RangedValue}; +use red_knot_python_semantic::{HasType, SemanticModel}; +use ruff_db::files::{File, FileRange}; +use ruff_db::parsed::{parsed_module, ParsedModule}; +use ruff_python_ast::{self as ast, AnyNodeRef}; +use ruff_python_parser::TokenKind; +use ruff_text_size::{Ranged, TextRange, TextSize}; + +pub fn goto_type_definition( + db: &dyn Db, + file: File, + offset: TextSize, +) -> Option> { + let parsed = parsed_module(db.upcast(), file); + let goto_target = find_goto_target(parsed, offset)?; + + let model = SemanticModel::new(db.upcast(), file); + + let ty = match goto_target { + GotoTarget::Expression(expression) => expression.inferred_type(&model), + GotoTarget::FunctionDef(function) => function.inferred_type(&model), + GotoTarget::ClassDef(class) => class.inferred_type(&model), + GotoTarget::Parameter(parameter) => parameter.inferred_type(&model), + GotoTarget::Alias(alias) => alias.inferred_type(&model), + GotoTarget::ExceptVariable(except) => except.inferred_type(&model), + GotoTarget::KeywordArgument(argument) => { + // TODO: Pyright resolves the declared type of the matching parameter. This seems more accurate + // than using the inferred value. + argument.value.inferred_type(&model) + } + // TODO: Support identifier targets + GotoTarget::PatternMatchRest(_) + | GotoTarget::PatternKeywordArgument(_) + | GotoTarget::PatternMatchStarName(_) + | GotoTarget::PatternMatchAsName(_) + | GotoTarget::ImportedModule(_) + | GotoTarget::TypeParamTypeVarName(_) + | GotoTarget::TypeParamParamSpecName(_) + | GotoTarget::TypeParamTypeVarTupleName(_) + | GotoTarget::NonLocal { .. } + | GotoTarget::Globals { .. } => return None, + }; + + tracing::debug!( + "Inferred type of covering node is {}", + ty.display(db.upcast()) + ); + + Some(RangedValue { + range: FileRange::new(file, goto_target.range()), + value: ty.navigation_targets(db), + }) +} + +#[derive(Clone, Copy, Debug)] +pub(crate) enum GotoTarget<'a> { + Expression(ast::ExprRef<'a>), + FunctionDef(&'a ast::StmtFunctionDef), + ClassDef(&'a ast::StmtClassDef), + Parameter(&'a ast::Parameter), + Alias(&'a ast::Alias), + + /// Go to on the module name of an import from + /// ```py + /// from foo import bar + /// ^^^ + /// ``` + ImportedModule(&'a ast::StmtImportFrom), + + /// Go to on the exception handler variable + /// ```py + /// try: ... + /// except Exception as e: ... + /// ^ + /// ``` + ExceptVariable(&'a ast::ExceptHandlerExceptHandler), + + /// Go to on a keyword argument + /// ```py + /// test(a = 1) + /// ^ + /// ``` + KeywordArgument(&'a ast::Keyword), + + /// Go to on the rest parameter of a pattern match + /// + /// ```py + /// match x: + /// case {"a": a, "b": b, **rest}: ... + /// ^^^^ + /// ``` + PatternMatchRest(&'a ast::PatternMatchMapping), + + /// Go to on a keyword argument of a class pattern + /// + /// ```py + /// match Point3D(0, 0, 0): + /// case Point3D(x=0, y=0, z=0): ... + /// ^ ^ ^ + /// ``` + PatternKeywordArgument(&'a ast::PatternKeyword), + + /// Go to on a pattern star argument + /// + /// ```py + /// match array: + /// case [*args]: ... + /// ^^^^ + PatternMatchStarName(&'a ast::PatternMatchStar), + + /// Go to on the name of a pattern match as pattern + /// + /// ```py + /// match x: + /// case [x] as y: ... + /// ^ + PatternMatchAsName(&'a ast::PatternMatchAs), + + /// Go to on the name of a type variable + /// + /// ```py + /// type Alias[T: int = bool] = list[T] + /// ^ + /// ``` + TypeParamTypeVarName(&'a ast::TypeParamTypeVar), + + /// Go to on the name of a type param spec + /// + /// ```py + /// type Alias[**P = [int, str]] = Callable[P, int] + /// ^ + /// ``` + TypeParamParamSpecName(&'a ast::TypeParamParamSpec), + + /// Go to on the name of a type var tuple + /// + /// ```py + /// type Alias[*Ts = ()] = tuple[*Ts] + /// ^^ + /// ``` + TypeParamTypeVarTupleName(&'a ast::TypeParamTypeVarTuple), + + NonLocal { + identifier: &'a ast::Identifier, + }, + Globals { + identifier: &'a ast::Identifier, + }, +} + +impl Ranged for GotoTarget<'_> { + fn range(&self) -> TextRange { + match self { + GotoTarget::Expression(expression) => expression.range(), + GotoTarget::FunctionDef(function) => function.name.range, + GotoTarget::ClassDef(class) => class.name.range, + GotoTarget::Parameter(parameter) => parameter.name.range, + GotoTarget::Alias(alias) => alias.name.range, + GotoTarget::ImportedModule(module) => module.module.as_ref().unwrap().range, + GotoTarget::ExceptVariable(except) => except.name.as_ref().unwrap().range, + GotoTarget::KeywordArgument(keyword) => keyword.arg.as_ref().unwrap().range, + GotoTarget::PatternMatchRest(rest) => rest.rest.as_ref().unwrap().range, + GotoTarget::PatternKeywordArgument(keyword) => keyword.attr.range, + GotoTarget::PatternMatchStarName(star) => star.name.as_ref().unwrap().range, + GotoTarget::PatternMatchAsName(as_name) => as_name.name.as_ref().unwrap().range, + GotoTarget::TypeParamTypeVarName(type_var) => type_var.name.range, + GotoTarget::TypeParamParamSpecName(spec) => spec.name.range, + GotoTarget::TypeParamTypeVarTupleName(tuple) => tuple.name.range, + GotoTarget::NonLocal { identifier, .. } => identifier.range, + GotoTarget::Globals { identifier, .. } => identifier.range, + } + } +} + +pub(crate) fn find_goto_target(parsed: &ParsedModule, offset: TextSize) -> Option { + let token = parsed.tokens().at_offset(offset).find(|token| { + matches!( + token.kind(), + TokenKind::Name + | TokenKind::String + | TokenKind::Complex + | TokenKind::Float + | TokenKind::Int + ) + })?; + let covering_node = covering_node(parsed.syntax().into(), token.range()) + .find(|node| node.is_identifier() || node.is_expression()) + .ok()?; + + tracing::trace!("Covering node is of kind {:?}", covering_node.node().kind()); + + match covering_node.node() { + AnyNodeRef::Identifier(identifier) => match covering_node.parent() { + Some(AnyNodeRef::StmtFunctionDef(function)) => Some(GotoTarget::FunctionDef(function)), + Some(AnyNodeRef::StmtClassDef(class)) => Some(GotoTarget::ClassDef(class)), + Some(AnyNodeRef::Parameter(parameter)) => Some(GotoTarget::Parameter(parameter)), + Some(AnyNodeRef::Alias(alias)) => Some(GotoTarget::Alias(alias)), + Some(AnyNodeRef::StmtImportFrom(from)) => Some(GotoTarget::ImportedModule(from)), + Some(AnyNodeRef::ExceptHandlerExceptHandler(handler)) => { + Some(GotoTarget::ExceptVariable(handler)) + } + Some(AnyNodeRef::Keyword(keyword)) => Some(GotoTarget::KeywordArgument(keyword)), + Some(AnyNodeRef::PatternMatchMapping(mapping)) => { + Some(GotoTarget::PatternMatchRest(mapping)) + } + Some(AnyNodeRef::PatternKeyword(keyword)) => { + Some(GotoTarget::PatternKeywordArgument(keyword)) + } + Some(AnyNodeRef::PatternMatchStar(star)) => { + Some(GotoTarget::PatternMatchStarName(star)) + } + Some(AnyNodeRef::PatternMatchAs(as_pattern)) => { + Some(GotoTarget::PatternMatchAsName(as_pattern)) + } + Some(AnyNodeRef::TypeParamTypeVar(var)) => Some(GotoTarget::TypeParamTypeVarName(var)), + Some(AnyNodeRef::TypeParamParamSpec(bound)) => { + Some(GotoTarget::TypeParamParamSpecName(bound)) + } + Some(AnyNodeRef::TypeParamTypeVarTuple(var_tuple)) => { + Some(GotoTarget::TypeParamTypeVarTupleName(var_tuple)) + } + Some(AnyNodeRef::ExprAttribute(attribute)) => { + Some(GotoTarget::Expression(attribute.into())) + } + Some(AnyNodeRef::StmtNonlocal(_)) => Some(GotoTarget::NonLocal { identifier }), + Some(AnyNodeRef::StmtGlobal(_)) => Some(GotoTarget::Globals { identifier }), + None => None, + Some(parent) => { + tracing::debug!( + "Missing `GoToTarget` for identifier with parent {:?}", + parent.kind() + ); + None + } + }, + + node => node.as_expr_ref().map(GotoTarget::Expression), + } +} + +#[cfg(test)] +mod tests { + + use crate::db::tests::TestDb; + use crate::{goto_type_definition, NavigationTarget}; + use insta::assert_snapshot; + use insta::internals::SettingsBindDropGuard; + use red_knot_python_semantic::{ + Program, ProgramSettings, PythonPath, PythonPlatform, SearchPathSettings, + }; + use ruff_db::diagnostic::{ + Annotation, Diagnostic, DiagnosticFormat, DiagnosticId, DisplayDiagnosticConfig, LintName, + Severity, Span, SubDiagnostic, + }; + use ruff_db::files::{system_path_to_file, File, FileRange}; + use ruff_db::system::{DbWithWritableSystem, SystemPath, SystemPathBuf}; + use ruff_python_ast::PythonVersion; + use ruff_text_size::{Ranged, TextSize}; + + #[test] + fn goto_type_of_expression_with_class_type() { + let test = goto_test( + r#" + class Test: ... + + ab = Test() + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> /main.py:2:19 + | + 2 | class Test: ... + | ^^^^ + 3 | + 4 | ab = Test() + | + info: Source + --> /main.py:4:13 + | + 2 | class Test: ... + 3 | + 4 | ab = Test() + | ^^ + | + "###); + } + + #[test] + fn goto_type_of_expression_with_function_type() { + let test = goto_test( + r#" + def foo(a, b): ... + + ab = foo + + ab + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> /main.py:2:17 + | + 2 | def foo(a, b): ... + | ^^^ + 3 | + 4 | ab = foo + | + info: Source + --> /main.py:6:13 + | + 4 | ab = foo + 5 | + 6 | ab + | ^^ + | + "###); + } + + #[test] + fn goto_type_of_expression_with_union_type() { + let test = goto_test( + r#" + + def foo(a, b): ... + + def bar(a, b): ... + + if random.choice(): + a = foo + else: + a = bar + + a + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> /main.py:3:17 + | + 3 | def foo(a, b): ... + | ^^^ + 4 | + 5 | def bar(a, b): ... + | + info: Source + --> /main.py:12:13 + | + 10 | a = bar + 11 | + 12 | a + | ^ + | + info: lint:goto-type-definition: Type definition + --> /main.py:5:17 + | + 3 | def foo(a, b): ... + 4 | + 5 | def bar(a, b): ... + | ^^^ + 6 | + 7 | if random.choice(): + | + info: Source + --> /main.py:12:13 + | + 10 | a = bar + 11 | + 12 | a + | ^ + | + "###); + } + + #[test] + fn goto_type_of_expression_with_module() { + let mut test = goto_test( + r#" + import lib + + lib + "#, + ); + + test.write_file("lib.py", "a = 10").unwrap(); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> /lib.py:1:1 + | + 1 | a = 10 + | ^ + | + info: Source + --> /main.py:4:13 + | + 2 | import lib + 3 | + 4 | lib + | ^^^ + | + "###); + } + + #[test] + fn goto_type_of_expression_with_literal_type() { + let test = goto_test( + r#" + a: str = "test" + + a + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> stdlib/builtins.pyi:443:7 + | + 441 | def __getitem__(self, key: int, /) -> str | int | None: ... + 442 | + 443 | class str(Sequence[str]): + | ^^^ + 444 | @overload + 445 | def __new__(cls, object: object = ...) -> Self: ... + | + info: Source + --> /main.py:4:13 + | + 2 | a: str = "test" + 3 | + 4 | a + | ^ + | + "###); + } + #[test] + fn goto_type_of_expression_with_literal_node() { + let test = goto_test( + r#" + a: str = "test" + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> stdlib/builtins.pyi:443:7 + | + 441 | def __getitem__(self, key: int, /) -> str | int | None: ... + 442 | + 443 | class str(Sequence[str]): + | ^^^ + 444 | @overload + 445 | def __new__(cls, object: object = ...) -> Self: ... + | + info: Source + --> /main.py:2:22 + | + 2 | a: str = "test" + | ^^^^^^ + | + "###); + } + + #[test] + fn goto_type_of_expression_with_type_var_type() { + let test = goto_test( + r#" + type Alias[T: int = bool] = list[T] + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> /main.py:2:24 + | + 2 | type Alias[T: int = bool] = list[T] + | ^ + | + info: Source + --> /main.py:2:46 + | + 2 | type Alias[T: int = bool] = list[T] + | ^ + | + "###); + } + + #[test] + fn goto_type_of_expression_with_type_param_spec() { + let test = goto_test( + r#" + type Alias[**P = [int, str]] = Callable[P, int] + "#, + ); + + // TODO: Goto type definition currently doesn't work for type param specs + // because the inference doesn't support them yet. + // This snapshot should show a single target pointing to `T` + assert_snapshot!(test.goto_type_definition(), @"No type definitions found"); + } + + #[test] + fn goto_type_of_expression_with_type_var_tuple() { + let test = goto_test( + r#" + type Alias[*Ts = ()] = tuple[*Ts] + "#, + ); + + // TODO: Goto type definition currently doesn't work for type var tuples + // because the inference doesn't support them yet. + // This snapshot should show a single target pointing to `T` + assert_snapshot!(test.goto_type_definition(), @"No type definitions found"); + } + + #[test] + fn goto_type_on_keyword_argument() { + let test = goto_test( + r#" + def test(a: str): ... + + test(a= "123") + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> stdlib/builtins.pyi:443:7 + | + 441 | def __getitem__(self, key: int, /) -> str | int | None: ... + 442 | + 443 | class str(Sequence[str]): + | ^^^ + 444 | @overload + 445 | def __new__(cls, object: object = ...) -> Self: ... + | + info: Source + --> /main.py:4:18 + | + 2 | def test(a: str): ... + 3 | + 4 | test(a= "123") + | ^ + | + "###); + } + + #[test] + fn goto_type_on_incorrectly_typed_keyword_argument() { + let test = goto_test( + r#" + def test(a: str): ... + + test(a= 123) + "#, + ); + + // TODO: This should jump to `str` and not `int` because + // the keyword is typed as a string. It's only the passed argument that + // is an int. Navigating to `str` would match pyright's behavior. + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> stdlib/builtins.pyi:234:7 + | + 232 | _LiteralInteger = _PositiveInteger | _NegativeInteger | Literal[0] # noqa: Y026 # TODO: Use TypeAlias once mypy bugs are fixed + 233 | + 234 | class int: + | ^^^ + 235 | @overload + 236 | def __new__(cls, x: ConvertibleToInt = ..., /) -> Self: ... + | + info: Source + --> /main.py:4:18 + | + 2 | def test(a: str): ... + 3 | + 4 | test(a= 123) + | ^ + | + "###); + } + + #[test] + fn goto_type_on_kwargs() { + let test = goto_test( + r#" + def f(name: str): ... + +kwargs = { "name": "test"} + +f(**kwargs) + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> stdlib/builtins.pyi:1098:7 + | + 1096 | def __class_getitem__(cls, item: Any, /) -> GenericAlias: ... + 1097 | + 1098 | class dict(MutableMapping[_KT, _VT]): + | ^^^^ + 1099 | # __init__ should be kept roughly in line with `collections.UserDict.__init__`, which has similar semantics + 1100 | # Also multiprocessing.managers.SyncManager.dict() + | + info: Source + --> /main.py:6:5 + | + 4 | kwargs = { "name": "test"} + 5 | + 6 | f(**kwargs) + | ^^^^^^ + | + "###); + } + + #[test] + fn goto_type_of_expression_with_builtin() { + let test = goto_test( + r#" + def foo(a: str): + a + "#, + ); + + // FIXME: This should go to `str` + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> stdlib/builtins.pyi:443:7 + | + 441 | def __getitem__(self, key: int, /) -> str | int | None: ... + 442 | + 443 | class str(Sequence[str]): + | ^^^ + 444 | @overload + 445 | def __new__(cls, object: object = ...) -> Self: ... + | + info: Source + --> /main.py:3:17 + | + 2 | def foo(a: str): + 3 | a + | ^ + | + "###); + } + + #[test] + fn goto_type_definition_cursor_between_object_and_attribute() { + let test = goto_test( + r#" + class X: + def foo(a, b): ... + + x = X() + + x.foo() + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> /main.py:2:19 + | + 2 | class X: + | ^ + 3 | def foo(a, b): ... + | + info: Source + --> /main.py:7:13 + | + 5 | x = X() + 6 | + 7 | x.foo() + | ^ + | + "###); + } + + #[test] + fn goto_between_call_arguments() { + let test = goto_test( + r#" + def foo(a, b): ... + + foo() + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> /main.py:2:17 + | + 2 | def foo(a, b): ... + | ^^^ + 3 | + 4 | foo() + | + info: Source + --> /main.py:4:13 + | + 2 | def foo(a, b): ... + 3 | + 4 | foo() + | ^^^ + | + "###); + } + + #[test] + fn goto_type_narrowing() { + let test = goto_test( + r#" + def foo(a: str | None, b): + if a is not None: + print(a) + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> stdlib/builtins.pyi:443:7 + | + 441 | def __getitem__(self, key: int, /) -> str | int | None: ... + 442 | + 443 | class str(Sequence[str]): + | ^^^ + 444 | @overload + 445 | def __new__(cls, object: object = ...) -> Self: ... + | + info: Source + --> /main.py:4:27 + | + 2 | def foo(a: str | None, b): + 3 | if a is not None: + 4 | print(a) + | ^ + | + "###); + } + + #[test] + fn goto_type_none() { + let test = goto_test( + r#" + def foo(a: str | None, b): + a + "#, + ); + + assert_snapshot!(test.goto_type_definition(), @r###" + info: lint:goto-type-definition: Type definition + --> stdlib/builtins.pyi:443:7 + | + 441 | def __getitem__(self, key: int, /) -> str | int | None: ... + 442 | + 443 | class str(Sequence[str]): + | ^^^ + 444 | @overload + 445 | def __new__(cls, object: object = ...) -> Self: ... + | + info: Source + --> /main.py:3:17 + | + 2 | def foo(a: str | None, b): + 3 | a + | ^ + | + info: lint:goto-type-definition: Type definition + --> stdlib/types.pyi:677:11 + | + 675 | if sys.version_info >= (3, 10): + 676 | @final + 677 | class NoneType: + | ^^^^^^^^ + 678 | def __bool__(self) -> Literal[False]: ... + | + info: Source + --> /main.py:3:17 + | + 2 | def foo(a: str | None, b): + 3 | a + | ^ + | + "###); + } + + fn goto_test(source: &str) -> GotoTest { + let mut db = TestDb::new(); + let cursor_offset = source.find("").expect( + "`source`` should contain a `` marker, indicating the position of the cursor.", + ); + + let mut content = source[..cursor_offset].to_string(); + content.push_str(&source[cursor_offset + "".len()..]); + + db.write_file("main.py", &content) + .expect("write to memory file system to be successful"); + + let file = system_path_to_file(&db, "main.py").expect("newly written file to existing"); + + Program::from_settings( + &db, + ProgramSettings { + python_version: PythonVersion::latest(), + python_platform: PythonPlatform::default(), + search_paths: SearchPathSettings { + extra_paths: vec![], + src_roots: vec![SystemPathBuf::from("/")], + custom_typeshed: None, + python_path: PythonPath::KnownSitePackages(vec![]), + }, + }, + ) + .expect("Default settings to be valid"); + + let mut insta_settings = insta::Settings::clone_current(); + insta_settings.add_filter(r#"\\(\w\w|\s|\.|")"#, "/$1"); + + let insta_settings_guard = insta_settings.bind_to_scope(); + + GotoTest { + db, + cursor_offset: TextSize::try_from(cursor_offset) + .expect("source to be smaller than 4GB"), + file, + _insta_settings_guard: insta_settings_guard, + } + } + + struct GotoTest { + db: TestDb, + cursor_offset: TextSize, + file: File, + _insta_settings_guard: SettingsBindDropGuard, + } + + impl GotoTest { + fn write_file( + &mut self, + path: impl AsRef, + content: &str, + ) -> std::io::Result<()> { + self.db.write_file(path, content) + } + + fn goto_type_definition(&self) -> String { + let Some(targets) = goto_type_definition(&self.db, self.file, self.cursor_offset) + else { + return "No goto target found".to_string(); + }; + + if targets.is_empty() { + return "No type definitions found".to_string(); + } + + let mut buf = vec![]; + + let source = targets.range; + + for target in &*targets { + GotoTypeDefinitionDiagnostic::new(source, target) + .into_diagnostic() + .print( + &self.db, + &DisplayDiagnosticConfig::default() + .color(false) + .format(DiagnosticFormat::Full), + &mut buf, + ) + .unwrap(); + } + + String::from_utf8(buf).unwrap() + } + } + + struct GotoTypeDefinitionDiagnostic { + source: FileRange, + target: FileRange, + } + + impl GotoTypeDefinitionDiagnostic { + fn new(source: FileRange, target: &NavigationTarget) -> Self { + Self { + source, + target: FileRange::new(target.file(), target.focus_range()), + } + } + + fn into_diagnostic(self) -> Diagnostic { + let mut source = SubDiagnostic::new(Severity::Info, "Source"); + source.annotate(Annotation::primary( + Span::from(self.source.file()).with_range(self.source.range()), + )); + + let mut main = Diagnostic::new( + DiagnosticId::Lint(LintName::of("goto-type-definition")), + Severity::Info, + "Type definition".to_string(), + ); + main.annotate(Annotation::primary( + Span::from(self.target.file()).with_range(self.target.range()), + )); + main.sub(source); + + main + } + } +} diff --git a/crates/red_knot_ide/src/lib.rs b/crates/red_knot_ide/src/lib.rs index 6f53e6094d..a2eaaac1e5 100644 --- a/crates/red_knot_ide/src/lib.rs +++ b/crates/red_knot_ide/src/lib.rs @@ -1,3 +1,257 @@ mod db; +mod find_node; +mod goto; + +use std::ops::{Deref, DerefMut}; pub use db::Db; +pub use goto::goto_type_definition; +use red_knot_python_semantic::types::{ + Class, ClassBase, ClassLiteralType, FunctionType, InstanceType, IntersectionType, + KnownInstanceType, ModuleLiteralType, Type, +}; +use ruff_db::files::{File, FileRange}; +use ruff_db::source::source_text; +use ruff_text_size::{Ranged, TextLen, TextRange}; + +/// Information associated with a text range. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub struct RangedValue { + pub range: FileRange, + pub value: T, +} + +impl Deref for RangedValue { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.value + } +} + +impl DerefMut for RangedValue { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.value + } +} + +impl IntoIterator for RangedValue +where + T: IntoIterator, +{ + type Item = T::Item; + type IntoIter = T::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.value.into_iter() + } +} + +/// Target to which the editor can navigate to. +#[derive(Debug, Clone)] +pub struct NavigationTarget { + file: File, + + /// The range that should be focused when navigating to the target. + /// + /// This is typically not the full range of the node. For example, it's the range of the class's name in a class definition. + /// + /// The `focus_range` must be fully covered by `full_range`. + focus_range: TextRange, + + /// The range covering the entire target. + full_range: TextRange, +} + +impl NavigationTarget { + pub fn file(&self) -> File { + self.file + } + + pub fn focus_range(&self) -> TextRange { + self.focus_range + } + + pub fn full_range(&self) -> TextRange { + self.full_range + } +} + +#[derive(Debug, Clone)] +pub struct NavigationTargets(smallvec::SmallVec<[NavigationTarget; 1]>); + +impl NavigationTargets { + fn single(target: NavigationTarget) -> Self { + Self(smallvec::smallvec![target]) + } + + fn empty() -> Self { + Self(smallvec::SmallVec::new()) + } + + fn iter(&self) -> std::slice::Iter<'_, NavigationTarget> { + self.0.iter() + } + + #[cfg(test)] + fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl IntoIterator for NavigationTargets { + type Item = NavigationTarget; + type IntoIter = smallvec::IntoIter<[NavigationTarget; 1]>; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a> IntoIterator for &'a NavigationTargets { + type Item = &'a NavigationTarget; + type IntoIter = std::slice::Iter<'a, NavigationTarget>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl FromIterator for NavigationTargets { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +pub trait HasNavigationTargets { + fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets; +} + +impl HasNavigationTargets for Type<'_> { + fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets { + match self { + Type::BoundMethod(method) => method.function(db).navigation_targets(db), + Type::FunctionLiteral(function) => function.navigation_targets(db), + Type::ModuleLiteral(module) => module.navigation_targets(db), + Type::Union(union) => union + .iter(db.upcast()) + .flat_map(|target| target.navigation_targets(db)) + .collect(), + Type::ClassLiteral(class) => class.navigation_targets(db), + Type::Instance(instance) => instance.navigation_targets(db), + Type::KnownInstance(instance) => instance.navigation_targets(db), + Type::SubclassOf(subclass_of_type) => match subclass_of_type.subclass_of() { + ClassBase::Class(class) => class.navigation_targets(db), + ClassBase::Dynamic(_) => NavigationTargets::empty(), + }, + + Type::StringLiteral(_) + | Type::BooleanLiteral(_) + | Type::LiteralString + | Type::IntLiteral(_) + | Type::BytesLiteral(_) + | Type::SliceLiteral(_) + | Type::MethodWrapper(_) + | Type::WrapperDescriptor(_) + | Type::PropertyInstance(_) + | Type::Tuple(_) => self.to_meta_type(db.upcast()).navigation_targets(db), + + Type::Intersection(intersection) => intersection.navigation_targets(db), + + Type::Dynamic(_) + | Type::Never + | Type::Callable(_) + | Type::AlwaysTruthy + | Type::AlwaysFalsy => NavigationTargets::empty(), + } + } +} + +impl HasNavigationTargets for FunctionType<'_> { + fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets { + let function_range = self.focus_range(db.upcast()); + NavigationTargets::single(NavigationTarget { + file: function_range.file(), + focus_range: function_range.range(), + full_range: self.full_range(db.upcast()).range(), + }) + } +} + +impl HasNavigationTargets for Class<'_> { + fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets { + let class_range = self.focus_range(db.upcast()); + NavigationTargets::single(NavigationTarget { + file: class_range.file(), + focus_range: class_range.range(), + full_range: self.full_range(db.upcast()).range(), + }) + } +} + +impl HasNavigationTargets for ClassLiteralType<'_> { + fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets { + self.class().navigation_targets(db) + } +} + +impl HasNavigationTargets for InstanceType<'_> { + fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets { + self.class().navigation_targets(db) + } +} + +impl HasNavigationTargets for ModuleLiteralType<'_> { + fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets { + let file = self.module(db).file(); + let source = source_text(db.upcast(), file); + + NavigationTargets::single(NavigationTarget { + file, + focus_range: TextRange::default(), + full_range: TextRange::up_to(source.text_len()), + }) + } +} + +impl HasNavigationTargets for KnownInstanceType<'_> { + fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets { + match self { + KnownInstanceType::TypeVar(var) => { + let definition = var.definition(db); + let full_range = definition.full_range(db.upcast()); + + NavigationTargets::single(NavigationTarget { + file: full_range.file(), + focus_range: definition.focus_range(db.upcast()).range(), + full_range: full_range.range(), + }) + } + + // TODO: Track the definition of `KnownInstance` and navigate to their definition. + _ => NavigationTargets::empty(), + } + } +} + +impl HasNavigationTargets for IntersectionType<'_> { + fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets { + // Only consider the positive elements because the negative elements are mainly from narrowing constraints. + let mut targets = self + .iter_positive(db.upcast()) + .filter(|ty| !ty.is_unknown()); + + let Some(first) = targets.next() else { + return NavigationTargets::empty(); + }; + + match targets.next() { + Some(_) => { + // If there are multiple types in the intersection, we can't navigate to a single one + // because the type is the intersection of all those types. + NavigationTargets::empty() + } + None => first.navigation_targets(db), + } + } +} diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index f96c5b19c2..fc05ced78e 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -1,6 +1,6 @@ use std::ops::Deref; -use ruff_db::files::File; +use ruff_db::files::{File, FileRange}; use ruff_db::parsed::ParsedModule; use ruff_python_ast as ast; use ruff_text_size::{Ranged, TextRange}; @@ -52,6 +52,14 @@ impl<'db> Definition<'db> { pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { self.file_scope(db).to_scope_id(db, self.file(db)) } + + pub fn full_range(self, db: &'db dyn Db) -> FileRange { + FileRange::new(self.file(db), self.kind(db).full_range()) + } + + pub fn focus_range(self, db: &'db dyn Db) -> FileRange { + FileRange::new(self.file(db), self.kind(db).target_range()) + } } /// One or more [`Definition`]s. @@ -559,8 +567,6 @@ impl DefinitionKind<'_> { /// /// A definition target would mainly be the node representing the symbol being defined i.e., /// [`ast::ExprName`] or [`ast::Identifier`] but could also be other nodes. - /// - /// This is mainly used for logging and debugging purposes. pub(crate) fn target_range(&self) -> TextRange { match self { DefinitionKind::Import(import) => import.alias().range(), @@ -587,6 +593,33 @@ impl DefinitionKind<'_> { } } + /// Returns the [`TextRange`] of the entire definition. + pub(crate) fn full_range(&self) -> TextRange { + match self { + DefinitionKind::Import(import) => import.alias().range(), + DefinitionKind::ImportFrom(import) => import.alias().range(), + DefinitionKind::StarImport(import) => import.import().range(), + DefinitionKind::Function(function) => function.range(), + DefinitionKind::Class(class) => class.range(), + DefinitionKind::TypeAlias(type_alias) => type_alias.range(), + DefinitionKind::NamedExpression(named) => named.range(), + DefinitionKind::Assignment(assignment) => assignment.name().range(), + DefinitionKind::AnnotatedAssignment(assign) => assign.range(), + DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.range(), + DefinitionKind::For(for_stmt) => for_stmt.name().range(), + DefinitionKind::Comprehension(comp) => comp.target().range(), + DefinitionKind::VariadicPositionalParameter(parameter) => parameter.range(), + DefinitionKind::VariadicKeywordParameter(parameter) => parameter.range(), + DefinitionKind::Parameter(parameter) => parameter.parameter.range(), + DefinitionKind::WithItem(with_item) => with_item.name().range(), + DefinitionKind::MatchPattern(match_pattern) => match_pattern.identifier.range(), + DefinitionKind::ExceptHandler(handler) => handler.node().range(), + DefinitionKind::TypeVar(type_var) => type_var.range(), + DefinitionKind::ParamSpec(param_spec) => param_spec.range(), + DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.range(), + } + } + pub(crate) fn category(&self, in_stub: bool) -> DefinitionCategory { match self { // functions, classes, and imports always bind, and we consider them declarations diff --git a/crates/red_knot_python_semantic/src/semantic_model.rs b/crates/red_knot_python_semantic/src/semantic_model.rs index 31d2eee775..01928c3ba5 100644 --- a/crates/red_knot_python_semantic/src/semantic_model.rs +++ b/crates/red_knot_python_semantic/src/semantic_model.rs @@ -160,6 +160,7 @@ impl_binding_has_ty!(ast::StmtFunctionDef); impl_binding_has_ty!(ast::StmtClassDef); impl_binding_has_ty!(ast::Parameter); impl_binding_has_ty!(ast::ParameterWithDefault); +impl_binding_has_ty!(ast::ExceptHandlerExceptHandler); impl HasType for ast::Alias { fn inferred_type<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 308cd66ee3..dc04d835a5 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -7,7 +7,7 @@ use call::{CallDunderError, CallError, CallErrorKind}; use context::InferContext; use diagnostic::{INVALID_CONTEXT_MANAGER, NOT_ITERABLE}; use itertools::EitherOrBoth; -use ruff_db::files::File; +use ruff_db::files::{File, FileRange}; use ruff_python_ast as ast; use ruff_python_ast::name::Name; use ruff_text_size::{Ranged, TextRange}; @@ -33,14 +33,16 @@ use crate::semantic_index::{imported_modules, semantic_index}; use crate::suppression::check_suppressions; use crate::symbol::{imported_symbol, Boundness, Symbol, SymbolAndQualifiers}; use crate::types::call::{Bindings, CallArgumentTypes}; -use crate::types::class_base::ClassBase; +pub use crate::types::class_base::ClassBase; use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION}; use crate::types::infer::infer_unpack_types; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters}; use crate::{Db, FxOrderSet, Module, Program}; -pub(crate) use class::{Class, ClassLiteralType, InstanceType, KnownClass, KnownInstanceType}; +pub use class::Class; +pub(crate) use class::KnownClass; +pub use class::{ClassLiteralType, InstanceType, KnownInstanceType}; mod builder; mod call; @@ -3785,6 +3787,9 @@ pub struct TypeVarInstance<'db> { #[return_ref] name: ast::name::Name, + /// The type var's definition + pub definition: Definition<'db>, + /// The upper bound or constraint on the type of this TypeVar bound_or_constraints: Option>, @@ -4461,6 +4466,21 @@ impl<'db> FunctionType<'db> { Type::Callable(CallableType::new(db, self.signature(db).clone())) } + /// Returns the [`FileRange`] of the function's name. + pub fn focus_range(self, db: &dyn Db) -> FileRange { + FileRange::new( + self.body_scope(db).file(db), + self.body_scope(db).node(db).expect_function().name.range, + ) + } + + pub fn full_range(self, db: &dyn Db) -> FileRange { + FileRange::new( + self.body_scope(db).file(db), + self.body_scope(db).node(db).expect_function().range, + ) + } + /// Typed externally-visible signature for this function. /// /// This is the signature as seen by external callers, possibly modified by decorators and/or @@ -4622,7 +4642,7 @@ impl KnownFunction { pub struct BoundMethodType<'db> { /// The function that is being bound. Corresponds to the `__func__` attribute on a /// bound method object - pub(crate) function: FunctionType<'db>, + pub function: FunctionType<'db>, /// The instance on which this method has been called. Corresponds to the `__self__` /// attribute on a bound method object self_instance: Type<'db>, @@ -5332,6 +5352,10 @@ impl<'db> UnionType<'db> { Self::from_elements(db, self.elements(db).iter().filter(filter_fn)) } + pub fn iter(&self, db: &'db dyn Db) -> Iter> { + self.elements(db).iter() + } + pub(crate) fn map_with_boundness( self, db: &'db dyn Db, @@ -5735,6 +5759,10 @@ impl<'db> IntersectionType<'db> { qualifiers, } } + + pub fn iter_positive(&self, db: &'db dyn Db) -> impl Iterator> { + self.positive(db).iter().copied() + } } #[salsa::interned(debug)] diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index ae1fe40b21..e337a4352c 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -18,7 +18,7 @@ use crate::{ }; use indexmap::IndexSet; use itertools::Itertools as _; -use ruff_db::files::File; +use ruff_db::files::{File, FileRange}; use ruff_python_ast::{self as ast, PythonVersion}; use rustc_hash::FxHashSet; @@ -153,6 +153,15 @@ impl<'db> Class<'db> { self.body_scope(db).node(db).expect_class() } + /// Returns the file range of the class's name. + pub fn focus_range(self, db: &dyn Db) -> FileRange { + FileRange::new(self.file(db), self.node(db).name.range) + } + + pub fn full_range(self, db: &dyn Db) -> FileRange { + FileRange::new(self.file(db), self.node(db).range) + } + /// Return the types of the decorators on this class #[salsa::tracked(return_ref)] fn decorators(self, db: &'db dyn Db) -> Box<[Type<'db>]> { @@ -754,7 +763,7 @@ pub struct ClassLiteralType<'db> { } impl<'db> ClassLiteralType<'db> { - pub(crate) fn class(self) -> Class<'db> { + pub fn class(self) -> Class<'db> { self.class } @@ -780,7 +789,7 @@ pub struct InstanceType<'db> { } impl<'db> InstanceType<'db> { - pub(super) fn class(self) -> Class<'db> { + pub fn class(self) -> Class<'db> { self.class } diff --git a/crates/red_knot_python_semantic/src/types/class_base.rs b/crates/red_knot_python_semantic/src/types/class_base.rs index d789cad2e3..6309d6fb83 100644 --- a/crates/red_knot_python_semantic/src/types/class_base.rs +++ b/crates/red_knot_python_semantic/src/types/class_base.rs @@ -8,7 +8,7 @@ use itertools::Either; /// all types that would be invalid to have as a class base are /// transformed into [`ClassBase::unknown`] #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, salsa::Update)] -pub(crate) enum ClassBase<'db> { +pub enum ClassBase<'db> { Dynamic(DynamicType), Class(Class<'db>), } @@ -18,7 +18,7 @@ impl<'db> ClassBase<'db> { Self::Dynamic(DynamicType::Any) } - pub(crate) const fn unknown() -> Self { + pub const fn unknown() -> Self { Self::Dynamic(DynamicType::Unknown) } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index c7bf461bbb..3eefb1112c 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2016,6 +2016,7 @@ impl<'db> TypeInferenceBuilder<'db> { let ty = Type::KnownInstance(KnownInstanceType::TypeVar(TypeVarInstance::new( self.db(), name.id.clone(), + definition, bound_or_constraint, default_ty, ))); diff --git a/crates/red_knot_python_semantic/src/types/subclass_of.rs b/crates/red_knot_python_semantic/src/types/subclass_of.rs index a28ea010ab..ab36d9ff36 100644 --- a/crates/red_knot_python_semantic/src/types/subclass_of.rs +++ b/crates/red_knot_python_semantic/src/types/subclass_of.rs @@ -52,7 +52,7 @@ impl<'db> SubclassOfType<'db> { } /// Return the inner [`ClassBase`] value wrapped by this `SubclassOfType`. - pub(crate) const fn subclass_of(self) -> ClassBase<'db> { + pub const fn subclass_of(self) -> ClassBase<'db> { self.subclass_of } diff --git a/crates/red_knot_server/Cargo.toml b/crates/red_knot_server/Cargo.toml index d38ca95e3d..a96508b459 100644 --- a/crates/red_knot_server/Cargo.toml +++ b/crates/red_knot_server/Cargo.toml @@ -11,7 +11,9 @@ repository = { workspace = true } license = { workspace = true } [dependencies] +red_knot_ide = { workspace = true } red_knot_project = { workspace = true } +red_knot_python_semantic = { workspace = true } ruff_db = { workspace = true, features = ["os"] } ruff_notebook = { workspace = true } diff --git a/crates/red_knot_server/src/edit.rs b/crates/red_knot_server/src/document.rs similarity index 83% rename from crates/red_knot_server/src/edit.rs rename to crates/red_knot_server/src/document.rs index 0345f8aac5..7d1c896044 100644 --- a/crates/red_knot_server/src/edit.rs +++ b/crates/red_knot_server/src/document.rs @@ -1,12 +1,14 @@ //! Types and utilities for working with text, modifying source files, and `Ruff <-> LSP` type conversion. +mod location; mod notebook; mod range; mod text_document; +pub(crate) use location::ToLink; use lsp_types::{PositionEncodingKind, Url}; pub use notebook::NotebookDocument; -pub(crate) use range::{RangeExt, ToRangeExt}; +pub(crate) use range::{FileRangeExt, PositionExt, RangeExt, ToRangeExt}; pub(crate) use text_document::DocumentVersion; pub use text_document::TextDocument; @@ -53,17 +55,17 @@ impl std::fmt::Display for DocumentKey { } } -impl From for lsp_types::PositionEncodingKind { +impl From for PositionEncodingKind { fn from(value: PositionEncoding) -> Self { match value { - PositionEncoding::UTF8 => lsp_types::PositionEncodingKind::UTF8, - PositionEncoding::UTF16 => lsp_types::PositionEncodingKind::UTF16, - PositionEncoding::UTF32 => lsp_types::PositionEncodingKind::UTF32, + PositionEncoding::UTF8 => PositionEncodingKind::UTF8, + PositionEncoding::UTF16 => PositionEncodingKind::UTF16, + PositionEncoding::UTF32 => PositionEncodingKind::UTF32, } } } -impl TryFrom<&lsp_types::PositionEncodingKind> for PositionEncoding { +impl TryFrom<&PositionEncodingKind> for PositionEncoding { type Error = (); fn try_from(value: &PositionEncodingKind) -> Result { diff --git a/crates/red_knot_server/src/document/location.rs b/crates/red_knot_server/src/document/location.rs new file mode 100644 index 0000000000..c655b45f31 --- /dev/null +++ b/crates/red_knot_server/src/document/location.rs @@ -0,0 +1,58 @@ +use crate::document::{FileRangeExt, ToRangeExt}; +use crate::system::file_to_url; +use crate::PositionEncoding; +use lsp_types::Location; +use red_knot_ide::{Db, NavigationTarget}; +use ruff_db::files::FileRange; +use ruff_db::source::{line_index, source_text}; +use ruff_text_size::Ranged; + +pub(crate) trait ToLink { + fn to_location( + &self, + db: &dyn red_knot_ide::Db, + encoding: PositionEncoding, + ) -> Option; + + fn to_link( + &self, + db: &dyn red_knot_ide::Db, + src: Option, + encoding: PositionEncoding, + ) -> Option; +} + +impl ToLink for NavigationTarget { + fn to_location(&self, db: &dyn Db, encoding: PositionEncoding) -> Option { + FileRange::new(self.file(), self.focus_range()).to_location(db.upcast(), encoding) + } + + fn to_link( + &self, + db: &dyn Db, + src: Option, + encoding: PositionEncoding, + ) -> Option { + let file = self.file(); + let uri = file_to_url(db.upcast(), file)?; + let source = source_text(db.upcast(), file); + let index = line_index(db.upcast(), file); + + let target_range = self.full_range().to_range(&source, &index, encoding); + let selection_range = self.focus_range().to_range(&source, &index, encoding); + + let src = src.map(|src| { + let source = source_text(db.upcast(), src.file()); + let index = line_index(db.upcast(), src.file()); + + src.range().to_range(&source, &index, encoding) + }); + + Some(lsp_types::LocationLink { + target_uri: uri, + target_range, + target_selection_range: selection_range, + origin_selection_range: src, + }) + } +} diff --git a/crates/red_knot_server/src/edit/notebook.rs b/crates/red_knot_server/src/document/notebook.rs similarity index 100% rename from crates/red_knot_server/src/edit/notebook.rs rename to crates/red_knot_server/src/document/notebook.rs diff --git a/crates/red_knot_server/src/edit/range.rs b/crates/red_knot_server/src/document/range.rs similarity index 77% rename from crates/red_knot_server/src/edit/range.rs rename to crates/red_knot_server/src/document/range.rs index 9ccef9e67d..30fbc75410 100644 --- a/crates/red_knot_server/src/edit/range.rs +++ b/crates/red_knot_server/src/document/range.rs @@ -1,10 +1,17 @@ use super::notebook; use super::PositionEncoding; +use crate::system::file_to_url; + use lsp_types as types; +use lsp_types::Location; + +use red_knot_python_semantic::Db; +use ruff_db::files::FileRange; +use ruff_db::source::{line_index, source_text}; use ruff_notebook::NotebookIndex; use ruff_source_file::OneIndexed; use ruff_source_file::{LineIndex, SourceLocation}; -use ruff_text_size::{TextRange, TextSize}; +use ruff_text_size::{Ranged, TextRange, TextSize}; pub(crate) struct NotebookRange { pub(crate) cell: notebook::CellId, @@ -16,6 +23,10 @@ pub(crate) trait RangeExt { -> TextRange; } +pub(crate) trait PositionExt { + fn to_text_size(&self, text: &str, index: &LineIndex, encoding: PositionEncoding) -> TextSize; +} + pub(crate) trait ToRangeExt { fn to_range(&self, text: &str, index: &LineIndex, encoding: PositionEncoding) -> types::Range; fn to_notebook_range( @@ -31,6 +42,41 @@ fn u32_index_to_usize(index: u32) -> usize { usize::try_from(index).expect("u32 fits in usize") } +impl PositionExt for lsp_types::Position { + fn to_text_size(&self, text: &str, index: &LineIndex, encoding: PositionEncoding) -> TextSize { + let start_line = index.line_range( + OneIndexed::from_zero_indexed(u32_index_to_usize(self.line)), + text, + ); + + let start_column_offset = match encoding { + PositionEncoding::UTF8 => TextSize::new(self.character), + + PositionEncoding::UTF16 => { + // Fast path for ASCII only documents + if index.is_ascii() { + TextSize::new(self.character) + } else { + // UTF16 encodes characters either as one or two 16 bit words. + // The position in `range` is the 16-bit word offset from the start of the line (and not the character offset) + // UTF-16 with a text that may use variable-length characters. + utf8_column_offset(self.character, &text[start_line]) + } + } + PositionEncoding::UTF32 => { + // UTF-32 uses 4 bytes for each character. Meaning, the position in range is a character offset. + return index.offset( + OneIndexed::from_zero_indexed(u32_index_to_usize(self.line)), + OneIndexed::from_zero_indexed(u32_index_to_usize(self.character)), + text, + ); + } + }; + + start_line.start() + start_column_offset.clamp(TextSize::new(0), start_line.end()) + } +} + impl RangeExt for lsp_types::Range { fn to_text_range( &self, @@ -38,58 +84,9 @@ impl RangeExt for lsp_types::Range { index: &LineIndex, encoding: PositionEncoding, ) -> TextRange { - let start_line = index.line_range( - OneIndexed::from_zero_indexed(u32_index_to_usize(self.start.line)), - text, - ); - let end_line = index.line_range( - OneIndexed::from_zero_indexed(u32_index_to_usize(self.end.line)), - text, - ); - - let (start_column_offset, end_column_offset) = match encoding { - PositionEncoding::UTF8 => ( - TextSize::new(self.start.character), - TextSize::new(self.end.character), - ), - - PositionEncoding::UTF16 => { - // Fast path for ASCII only documents - if index.is_ascii() { - ( - TextSize::new(self.start.character), - TextSize::new(self.end.character), - ) - } else { - // UTF16 encodes characters either as one or two 16 bit words. - // The position in `range` is the 16-bit word offset from the start of the line (and not the character offset) - // UTF-16 with a text that may use variable-length characters. - ( - utf8_column_offset(self.start.character, &text[start_line]), - utf8_column_offset(self.end.character, &text[end_line]), - ) - } - } - PositionEncoding::UTF32 => { - // UTF-32 uses 4 bytes for each character. Meaning, the position in range is a character offset. - return TextRange::new( - index.offset( - OneIndexed::from_zero_indexed(u32_index_to_usize(self.start.line)), - OneIndexed::from_zero_indexed(u32_index_to_usize(self.start.character)), - text, - ), - index.offset( - OneIndexed::from_zero_indexed(u32_index_to_usize(self.end.line)), - OneIndexed::from_zero_indexed(u32_index_to_usize(self.end.character)), - text, - ), - ); - } - }; - TextRange::new( - start_line.start() + start_column_offset.clamp(TextSize::new(0), start_line.end()), - end_line.start() + end_column_offset.clamp(TextSize::new(0), end_line.end()), + self.start.to_text_size(text, index, encoding), + self.end.to_text_size(text, index, encoding), ) } } @@ -213,3 +210,19 @@ fn source_location_to_position(location: &SourceLocation) -> types::Position { .expect("character usize fits in u32"), } } + +pub(crate) trait FileRangeExt { + fn to_location(&self, db: &dyn Db, encoding: PositionEncoding) -> Option; +} + +impl FileRangeExt for FileRange { + fn to_location(&self, db: &dyn Db, encoding: PositionEncoding) -> Option { + let file = self.file(); + let uri = file_to_url(db, file)?; + let source = source_text(db.upcast(), file); + let line_index = line_index(db.upcast(), file); + + let range = self.range().to_range(&source, &line_index, encoding); + Some(Location { uri, range }) + } +} diff --git a/crates/red_knot_server/src/edit/text_document.rs b/crates/red_knot_server/src/document/text_document.rs similarity index 100% rename from crates/red_knot_server/src/edit/text_document.rs rename to crates/red_knot_server/src/document/text_document.rs diff --git a/crates/red_knot_server/src/lib.rs b/crates/red_knot_server/src/lib.rs index 7c149a9ae8..4efd797431 100644 --- a/crates/red_knot_server/src/lib.rs +++ b/crates/red_knot_server/src/lib.rs @@ -1,16 +1,15 @@ #![allow(dead_code)] +use crate::server::Server; use anyhow::Context; -pub use edit::{DocumentKey, NotebookDocument, PositionEncoding, TextDocument}; +pub use document::{DocumentKey, NotebookDocument, PositionEncoding, TextDocument}; pub use session::{ClientSettings, DocumentQuery, DocumentSnapshot, Session}; use std::num::NonZeroUsize; -use crate::server::Server; - #[macro_use] mod message; -mod edit; +mod document; mod logging; mod server; mod session; diff --git a/crates/red_knot_server/src/server.rs b/crates/red_knot_server/src/server.rs index 4a29644a87..2e0f88ead6 100644 --- a/crates/red_knot_server/src/server.rs +++ b/crates/red_knot_server/src/server.rs @@ -9,7 +9,7 @@ use lsp_server::Message; use lsp_types::{ ClientCapabilities, DiagnosticOptions, DiagnosticServerCapabilities, MessageType, ServerCapabilities, TextDocumentSyncCapability, TextDocumentSyncKind, TextDocumentSyncOptions, - Url, + TypeDefinitionProviderCapability, Url, }; use self::connection::{Connection, ConnectionInitializer}; @@ -220,6 +220,7 @@ impl Server { ..Default::default() }, )), + type_definition_provider: Some(TypeDefinitionProviderCapability::Simple(true)), ..Default::default() } } diff --git a/crates/red_knot_server/src/server/api.rs b/crates/red_knot_server/src/server/api.rs index 6a49f7758a..7b535729c9 100644 --- a/crates/red_knot_server/src/server/api.rs +++ b/crates/red_knot_server/src/server/api.rs @@ -26,6 +26,12 @@ pub(super) fn request<'a>(req: server::Request) -> Task<'a> { BackgroundSchedule::LatencySensitive, ) } + request::GotoTypeDefinitionRequestHandler::METHOD => { + background_request_task::( + req, + BackgroundSchedule::LatencySensitive, + ) + } method => { tracing::warn!("Received request {method} which does not have a handler"); return Task::nothing(); @@ -80,6 +86,10 @@ fn _local_request_task<'a, R: traits::SyncRequestHandler>( })) } +// TODO(micha): Calls to `db` could panic if the db gets mutated while this task is running. +// We should either wrap `R::run_with_snapshot` with a salsa catch cancellation handler or +// use `SemanticModel` instead of passing `db` which uses a Result for all it's methods +// that propagate cancellations. fn background_request_task<'a, R: traits::BackgroundDocumentRequestHandler>( req: server::Request, schedule: BackgroundSchedule, diff --git a/crates/red_knot_server/src/server/api/notifications/did_open_notebook.rs b/crates/red_knot_server/src/server/api/notifications/did_open_notebook.rs index 8fc2bf6936..ea355e7e0f 100644 --- a/crates/red_knot_server/src/server/api/notifications/did_open_notebook.rs +++ b/crates/red_knot_server/src/server/api/notifications/did_open_notebook.rs @@ -5,7 +5,7 @@ use lsp_types::DidOpenNotebookDocumentParams; use red_knot_project::watch::ChangeEvent; use ruff_db::Db; -use crate::edit::NotebookDocument; +use crate::document::NotebookDocument; use crate::server::api::traits::{NotificationHandler, SyncNotificationHandler}; use crate::server::api::LSPResult; use crate::server::client::{Notifier, Requester}; diff --git a/crates/red_knot_server/src/server/api/requests.rs b/crates/red_knot_server/src/server/api/requests.rs index 83e25fc6ed..26c45b7ea7 100644 --- a/crates/red_knot_server/src/server/api/requests.rs +++ b/crates/red_knot_server/src/server/api/requests.rs @@ -1,3 +1,5 @@ mod diagnostic; +mod goto_type_definition; pub(super) use diagnostic::DocumentDiagnosticRequestHandler; +pub(super) use goto_type_definition::GotoTypeDefinitionRequestHandler; diff --git a/crates/red_knot_server/src/server/api/requests/diagnostic.rs b/crates/red_knot_server/src/server/api/requests/diagnostic.rs index fc207fe7bd..89012f259e 100644 --- a/crates/red_knot_server/src/server/api/requests/diagnostic.rs +++ b/crates/red_knot_server/src/server/api/requests/diagnostic.rs @@ -7,7 +7,7 @@ use lsp_types::{ RelatedFullDocumentDiagnosticReport, Url, }; -use crate::edit::ToRangeExt; +use crate::document::ToRangeExt; use crate::server::api::traits::{BackgroundDocumentRequestHandler, RequestHandler}; use crate::server::{client::Notifier, Result}; use crate::session::DocumentSnapshot; diff --git a/crates/red_knot_server/src/server/api/requests/goto_type_definition.rs b/crates/red_knot_server/src/server/api/requests/goto_type_definition.rs new file mode 100644 index 0000000000..bb3a4e6e58 --- /dev/null +++ b/crates/red_knot_server/src/server/api/requests/goto_type_definition.rs @@ -0,0 +1,68 @@ +use std::borrow::Cow; + +use lsp_types::request::{GotoTypeDefinition, GotoTypeDefinitionParams}; +use lsp_types::{GotoDefinitionResponse, Url}; +use red_knot_ide::goto_type_definition; +use red_knot_project::ProjectDatabase; +use ruff_db::source::{line_index, source_text}; + +use crate::document::{PositionExt, ToLink}; +use crate::server::api::traits::{BackgroundDocumentRequestHandler, RequestHandler}; +use crate::server::client::Notifier; +use crate::DocumentSnapshot; + +pub(crate) struct GotoTypeDefinitionRequestHandler; + +impl RequestHandler for GotoTypeDefinitionRequestHandler { + type RequestType = GotoTypeDefinition; +} + +impl BackgroundDocumentRequestHandler for GotoTypeDefinitionRequestHandler { + fn document_url(params: &GotoTypeDefinitionParams) -> Cow { + Cow::Borrowed(¶ms.text_document_position_params.text_document.uri) + } + + fn run_with_snapshot( + snapshot: DocumentSnapshot, + db: ProjectDatabase, + _notifier: Notifier, + params: GotoTypeDefinitionParams, + ) -> crate::server::Result> { + let Some(file) = snapshot.file(&db) else { + tracing::debug!("Failed to resolve file for {:?}", params); + return Ok(None); + }; + + let source = source_text(&db, file); + let line_index = line_index(&db, file); + let offset = params.text_document_position_params.position.to_text_size( + &source, + &line_index, + snapshot.encoding(), + ); + + let Some(ranged) = goto_type_definition(&db, file, offset) else { + return Ok(None); + }; + + if snapshot + .resolved_client_capabilities() + .type_definition_link_support + { + let src = Some(ranged.range); + let links: Vec<_> = ranged + .into_iter() + .filter_map(|target| target.to_link(&db, src, snapshot.encoding())) + .collect(); + + Ok(Some(GotoDefinitionResponse::Link(links))) + } else { + let locations: Vec<_> = ranged + .into_iter() + .filter_map(|target| target.to_location(&db, snapshot.encoding())) + .collect(); + + Ok(Some(GotoDefinitionResponse::Array(locations))) + } + } +} diff --git a/crates/red_knot_server/src/session.rs b/crates/red_knot_server/src/session.rs index 6e370418be..470592e8bd 100644 --- a/crates/red_knot_server/src/session.rs +++ b/crates/red_knot_server/src/session.rs @@ -13,7 +13,7 @@ use ruff_db::files::{system_path_to_file, File}; use ruff_db::system::SystemPath; use ruff_db::Db; -use crate::edit::{DocumentKey, DocumentVersion, NotebookDocument}; +use crate::document::{DocumentKey, DocumentVersion, NotebookDocument}; use crate::system::{url_to_any_system_path, AnySystemPath, LSPSystem}; use crate::{PositionEncoding, TextDocument}; @@ -272,7 +272,7 @@ impl DocumentSnapshot { self.position_encoding } - pub(crate) fn file(&self, db: &ProjectDatabase) -> Option { + pub(crate) fn file(&self, db: &dyn Db) -> Option { match url_to_any_system_path(self.document_ref.file_url()).ok()? { AnySystemPath::System(path) => system_path_to_file(db, path).ok(), AnySystemPath::SystemVirtual(virtual_path) => db diff --git a/crates/red_knot_server/src/session/capabilities.rs b/crates/red_knot_server/src/session/capabilities.rs index 27d5d09ce7..ba27153c9b 100644 --- a/crates/red_knot_server/src/session/capabilities.rs +++ b/crates/red_knot_server/src/session/capabilities.rs @@ -8,6 +8,8 @@ pub(crate) struct ResolvedClientCapabilities { pub(crate) document_changes: bool, pub(crate) workspace_refresh: bool, pub(crate) pull_diagnostics: bool, + /// Whether `textDocument.typeDefinition.linkSupport` is `true` + pub(crate) type_definition_link_support: bool, } impl ResolvedClientCapabilities { @@ -36,6 +38,12 @@ impl ResolvedClientCapabilities { .and_then(|workspace_edit| workspace_edit.document_changes) .unwrap_or_default(); + let declaration_link_support = client_capabilities + .text_document + .as_ref() + .and_then(|document| document.type_definition?.link_support) + .unwrap_or_default(); + let workspace_refresh = true; // TODO(jane): Once the bug involving workspace.diagnostic(s) deserialization has been fixed, @@ -62,6 +70,7 @@ impl ResolvedClientCapabilities { document_changes, workspace_refresh, pull_diagnostics, + type_definition_link_support: declaration_link_support, } } } diff --git a/crates/red_knot_server/src/session/index.rs b/crates/red_knot_server/src/session/index.rs index 8b455ac392..2ef79c5dea 100644 --- a/crates/red_knot_server/src/session/index.rs +++ b/crates/red_knot_server/src/session/index.rs @@ -6,7 +6,7 @@ use lsp_types::Url; use rustc_hash::FxHashMap; use crate::{ - edit::{DocumentKey, DocumentVersion, NotebookDocument}, + document::{DocumentKey, DocumentVersion, NotebookDocument}, PositionEncoding, TextDocument, }; diff --git a/crates/red_knot_server/src/system.rs b/crates/red_knot_server/src/system.rs index 44c780c3f5..12d4057a72 100644 --- a/crates/red_knot_server/src/system.rs +++ b/crates/red_knot_server/src/system.rs @@ -3,8 +3,9 @@ use std::fmt::Display; use std::sync::Arc; use lsp_types::Url; - +use red_knot_python_semantic::Db; use ruff_db::file_revision::FileRevision; +use ruff_db::files::{File, FilePath}; use ruff_db::system::walk_directory::WalkDirectoryBuilder; use ruff_db::system::{ CaseSensitivity, DirectoryEntry, FileType, GlobError, Metadata, OsSystem, PatternError, Result, @@ -35,6 +36,16 @@ pub(crate) fn url_to_any_system_path(url: &Url) -> std::result::Result Option { + match file.path(db) { + FilePath::System(system) => Url::from_file_path(system.as_std_path()).ok(), + FilePath::SystemVirtual(path) => Url::parse(path.as_str()).ok(), + // TODO: Not yet supported, consider an approach similar to Sorbet's custom paths + // https://sorbet.org/docs/sorbet-uris + FilePath::Vendored(_) => None, + } +} + /// Represents either a [`SystemPath`] or a [`SystemVirtualPath`]. #[derive(Debug)] pub(crate) enum AnySystemPath { diff --git a/crates/ruff_db/src/files.rs b/crates/ruff_db/src/files.rs index 9b7cabe426..acb3faa35b 100644 --- a/crates/ruff_db/src/files.rs +++ b/crates/ruff_db/src/files.rs @@ -7,6 +7,7 @@ pub use file_root::{FileRoot, FileRootKind}; pub use path::FilePath; use ruff_notebook::{Notebook, NotebookError}; use ruff_python_ast::PySourceType; +use ruff_text_size::{Ranged, TextRange}; use salsa::plumbing::AsId; use salsa::{Durability, Setter}; @@ -510,6 +511,30 @@ impl fmt::Display for FileError { impl std::error::Error for FileError {} +/// Range with its corresponding file. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct FileRange { + file: File, + range: TextRange, +} + +impl FileRange { + pub const fn new(file: File, range: TextRange) -> Self { + Self { file, range } + } + + pub const fn file(&self) -> File { + self.file + } +} + +impl Ranged for FileRange { + #[inline] + fn range(&self) -> TextRange { + self.range + } +} + #[cfg(test)] mod tests { use crate::file_revision::FileRevision; diff --git a/crates/ruff_python_parser/src/lib.rs b/crates/ruff_python_parser/src/lib.rs index c2a60f544e..d63d250493 100644 --- a/crates/ruff_python_parser/src/lib.rs +++ b/crates/ruff_python_parser/src/lib.rs @@ -63,7 +63,6 @@ //! [lexical analysis]: https://en.wikipedia.org/wiki/Lexical_analysis //! [parsing]: https://en.wikipedia.org/wiki/Parsing //! [lexer]: crate::lexer - use std::iter::FusedIterator; use std::ops::Deref; @@ -558,6 +557,86 @@ impl Tokens { } } + /// Searches the token(s) at `offset`. + /// + /// Returns [`TokenAt::Between`] if `offset` points directly inbetween two tokens + /// (the left token ends at `offset` and the right token starts at `offset`). + /// + /// + /// ## Examples + /// + /// [Playground](https://play.ruff.rs/f3ad0a55-5931-4a13-96c7-b2b8bfdc9a2e?secondary=Tokens) + /// + /// ``` + /// # use ruff_python_ast::PySourceType; + /// # use ruff_python_parser::{Token, TokenAt, TokenKind}; + /// # use ruff_text_size::{Ranged, TextSize}; + /// + /// let source = r#" + /// def test(arg): + /// arg.call() + /// if True: + /// pass + /// print("true") + /// "#.trim(); + /// + /// let parsed = ruff_python_parser::parse_unchecked_source(source, PySourceType::Python); + /// let tokens = parsed.tokens(); + /// + /// let collect_tokens = |offset: TextSize| { + /// tokens.at_offset(offset).into_iter().map(|t| (t.kind(), &source[t.range()])).collect::>() + /// }; + /// + /// assert_eq!(collect_tokens(TextSize::new(4)), vec! [(TokenKind::Name, "test")]); + /// assert_eq!(collect_tokens(TextSize::new(6)), vec! [(TokenKind::Name, "test")]); + /// // between `arg` and `.` + /// assert_eq!(collect_tokens(TextSize::new(22)), vec! [(TokenKind::Name, "arg"), (TokenKind::Dot, ".")]); + /// assert_eq!(collect_tokens(TextSize::new(36)), vec! [(TokenKind::If, "if")]); + /// // Before the dedent token + /// assert_eq!(collect_tokens(TextSize::new(57)), vec! []); + /// ``` + pub fn at_offset(&self, offset: TextSize) -> TokenAt { + match self.binary_search_by_key(&offset, ruff_text_size::Ranged::start) { + // The token at `index` starts exactly at `offset. + // ```python + // object.attribute + // ^ OFFSET + // ``` + Ok(index) => { + let token = self[index]; + // `token` starts exactly at `offset`. Test if the offset is right between + // `token` and the previous token (if there's any) + if let Some(previous) = index.checked_sub(1).map(|idx| self[idx]) { + if previous.end() == offset { + return TokenAt::Between(previous, token); + } + } + + TokenAt::Single(token) + } + + // No token found that starts exactly at the given offset. But it's possible that + // the token starting before `offset` fully encloses `offset` (it's end range ends after `offset`). + // ```python + // object.attribute + // ^ OFFSET + // # or + // if True: + // print("test") + // ^ OFFSET + // ``` + Err(index) => { + if let Some(previous) = index.checked_sub(1).map(|idx| self[idx]) { + if previous.range().contains_inclusive(offset) { + return TokenAt::Single(previous); + } + } + + TokenAt::None + } + } + } + /// Returns a slice of tokens after the given [`TextSize`] offset. /// /// If the given offset is between two tokens, the returned slice will start from the following @@ -610,6 +689,39 @@ impl Deref for Tokens { } } +/// A token that encloses a given offset or ends exactly at it. +pub enum TokenAt { + /// There's no token at the given offset + None, + + /// There's a single token at the given offset. + Single(Token), + + /// The offset falls exactly between two tokens. E.g. `CURSOR` in `call(arguments)` is + /// positioned exactly between the `call` and `(` tokens. + Between(Token, Token), +} + +impl Iterator for TokenAt { + type Item = Token; + + fn next(&mut self) -> Option { + match *self { + TokenAt::None => None, + TokenAt::Single(token) => { + *self = TokenAt::None; + Some(token) + } + TokenAt::Between(first, second) => { + *self = TokenAt::Single(second); + Some(first) + } + } + } +} + +impl FusedIterator for TokenAt {} + impl From<&Tokens> for CommentRanges { fn from(tokens: &Tokens) -> Self { let mut ranges = vec![];