diff --git a/crates/ruff_python_ast/src/helpers.rs b/crates/ruff_python_ast/src/helpers.rs index c9dce0160c..a205e027e1 100644 --- a/crates/ruff_python_ast/src/helpers.rs +++ b/crates/ruff_python_ast/src/helpers.rs @@ -1,5 +1,5 @@ use std::borrow::Cow; -use std::ops::Sub; +use std::ops::{Add, Sub}; use std::path::Path; use itertools::Itertools; @@ -1071,6 +1071,21 @@ pub fn match_parens(start: TextSize, locator: &Locator) -> Option { } } +/// Return `true` if the given character is a valid identifier character. +fn is_identifier(c: char) -> bool { + c.is_alphanumeric() || c == '_' +} + +#[derive(Debug)] +enum IdentifierState { + /// We're in a comment, awaiting the identifier at the given index. + InComment { index: usize }, + /// We're looking for the identifier at the given index. + AwaitingIdentifier { index: usize }, + /// We're in the identifier at the given index, starting at the given character. + InIdentifier { index: usize, start: TextSize }, +} + /// Return the appropriate visual `Range` for any message that spans a `Stmt`. /// Specifically, this method returns the range of a function or class name, /// rather than that of the entire function or class body. @@ -1095,17 +1110,59 @@ pub fn identifier_range(stmt: &Stmt, locator: &Locator) -> TextRange { TextRange::new(last_decorator.end(), range.end()) }); - let contents = locator.slice(header_range); + // If the statement is an async function, we're looking for the third + // keyword-or-identifier (`foo` in `async def foo()`). Otherwise, it's the + // second keyword-or-identifier (`foo` in `def foo()` or `Foo` in `class Foo`). + let name_index = if stmt.is_async_function_def_stmt() { + 2 + } else { + 1 + }; - let mut tokens = - lexer::lex_starts_at(contents, Mode::Module, header_range.start()).flatten(); - tokens - .find_map(|(t, range)| t.is_name().then_some(range)) - .unwrap_or_else(|| { - error!("Failed to find identifier for {:?}", stmt); + let mut state = IdentifierState::AwaitingIdentifier { index: 0 }; + for (char_index, char) in locator.slice(header_range).char_indices() { + match state { + IdentifierState::InComment { index } => match char { + // Read until the end of the comment. + '\r' | '\n' => { + state = IdentifierState::AwaitingIdentifier { index }; + } + _ => {} + }, + IdentifierState::AwaitingIdentifier { index } => match char { + // Read until we hit an identifier. + '#' => { + state = IdentifierState::InComment { index }; + } + c if is_identifier(c) => { + state = IdentifierState::InIdentifier { + index, + start: TextSize::try_from(char_index).unwrap(), + }; + } + _ => {} + }, + IdentifierState::InIdentifier { index, start } => { + // We've reached the end of the identifier. + if !is_identifier(char) { + if index == name_index { + // We've found the identifier we're looking for. + let end = TextSize::try_from(char_index).unwrap(); + return TextRange::new( + header_range.start().add(start), + header_range.start().add(end), + ); + } - header_range - }) + // We're looking for a different identifier. + state = IdentifierState::AwaitingIdentifier { index: index + 1 }; + } + } + } + } + + error!("Failed to find identifier for {:?}", stmt); + header_range } _ => stmt.range(), } @@ -1681,6 +1738,14 @@ y = 2 TextRange::new(TextSize::from(4), TextSize::from(5)) ); + let contents = "async def f(): pass".trim(); + let stmt = Stmt::parse(contents, "")?; + let locator = Locator::new(contents); + assert_eq!( + identifier_range(&stmt, &locator), + TextRange::new(TextSize::from(10), TextSize::from(11)) + ); + let contents = r#" def \ f(): @@ -1723,6 +1788,19 @@ class Class(): TextRange::new(TextSize::from(19), TextSize::from(24)) ); + let contents = r#" +@decorator() # Comment +class Class(): + pass +"# + .trim(); + let stmt = Stmt::parse(contents, "")?; + let locator = Locator::new(contents); + assert_eq!( + identifier_range(&stmt, &locator), + TextRange::new(TextSize::from(30), TextSize::from(35)) + ); + let contents = r#"x = y + 1"#.trim(); let stmt = Stmt::parse(contents, "")?; let locator = Locator::new(contents);