Remove lexer dependency from identifier_range (#5036)

## Summary

We run this quite a bit -- the new version is zero-allocation, though
it's not quite as nice as the lexer we have in the formatter.
This commit is contained in:
Charlie Marsh 2023-06-12 18:06:03 -04:00 committed by GitHub
parent ab11dd08df
commit 7e37d8916c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,5 +1,5 @@
use std::borrow::Cow;
use std::ops::Sub;
use std::ops::{Add, Sub};
use std::path::Path;
use itertools::Itertools;
@ -1071,6 +1071,21 @@ pub fn match_parens(start: TextSize, locator: &Locator) -> Option<TextRange> {
}
}
/// Return `true` if the given character is a valid identifier character.
fn is_identifier(c: char) -> bool {
c.is_alphanumeric() || c == '_'
}
#[derive(Debug)]
enum IdentifierState {
/// We're in a comment, awaiting the identifier at the given index.
InComment { index: usize },
/// We're looking for the identifier at the given index.
AwaitingIdentifier { index: usize },
/// We're in the identifier at the given index, starting at the given character.
InIdentifier { index: usize, start: TextSize },
}
/// Return the appropriate visual `Range` for any message that spans a `Stmt`.
/// Specifically, this method returns the range of a function or class name,
/// rather than that of the entire function or class body.
@ -1095,17 +1110,59 @@ pub fn identifier_range(stmt: &Stmt, locator: &Locator) -> TextRange {
TextRange::new(last_decorator.end(), range.end())
});
let contents = locator.slice(header_range);
// If the statement is an async function, we're looking for the third
// keyword-or-identifier (`foo` in `async def foo()`). Otherwise, it's the
// second keyword-or-identifier (`foo` in `def foo()` or `Foo` in `class Foo`).
let name_index = if stmt.is_async_function_def_stmt() {
2
} else {
1
};
let mut tokens =
lexer::lex_starts_at(contents, Mode::Module, header_range.start()).flatten();
tokens
.find_map(|(t, range)| t.is_name().then_some(range))
.unwrap_or_else(|| {
error!("Failed to find identifier for {:?}", stmt);
let mut state = IdentifierState::AwaitingIdentifier { index: 0 };
for (char_index, char) in locator.slice(header_range).char_indices() {
match state {
IdentifierState::InComment { index } => match char {
// Read until the end of the comment.
'\r' | '\n' => {
state = IdentifierState::AwaitingIdentifier { index };
}
_ => {}
},
IdentifierState::AwaitingIdentifier { index } => match char {
// Read until we hit an identifier.
'#' => {
state = IdentifierState::InComment { index };
}
c if is_identifier(c) => {
state = IdentifierState::InIdentifier {
index,
start: TextSize::try_from(char_index).unwrap(),
};
}
_ => {}
},
IdentifierState::InIdentifier { index, start } => {
// We've reached the end of the identifier.
if !is_identifier(char) {
if index == name_index {
// We've found the identifier we're looking for.
let end = TextSize::try_from(char_index).unwrap();
return TextRange::new(
header_range.start().add(start),
header_range.start().add(end),
);
}
header_range
})
// We're looking for a different identifier.
state = IdentifierState::AwaitingIdentifier { index: index + 1 };
}
}
}
}
error!("Failed to find identifier for {:?}", stmt);
header_range
}
_ => stmt.range(),
}
@ -1681,6 +1738,14 @@ y = 2
TextRange::new(TextSize::from(4), TextSize::from(5))
);
let contents = "async def f(): pass".trim();
let stmt = Stmt::parse(contents, "<filename>")?;
let locator = Locator::new(contents);
assert_eq!(
identifier_range(&stmt, &locator),
TextRange::new(TextSize::from(10), TextSize::from(11))
);
let contents = r#"
def \
f():
@ -1723,6 +1788,19 @@ class Class():
TextRange::new(TextSize::from(19), TextSize::from(24))
);
let contents = r#"
@decorator() # Comment
class Class():
pass
"#
.trim();
let stmt = Stmt::parse(contents, "<filename>")?;
let locator = Locator::new(contents);
assert_eq!(
identifier_range(&stmt, &locator),
TextRange::new(TextSize::from(30), TextSize::from(35))
);
let contents = r#"x = y + 1"#.trim();
let stmt = Stmt::parse(contents, "<filename>")?;
let locator = Locator::new(contents);