Replace row/column based Location with byte-offsets. (#3931)

This commit is contained in:
Micha Reiser 2023-04-26 20:11:02 +02:00 committed by GitHub
parent ee91598835
commit cab65b25da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
418 changed files with 6203 additions and 7040 deletions

View file

@ -9,7 +9,7 @@ rust-version = { workspace = true }
ruff_formatter = { path = "../ruff_formatter" }
ruff_python_ast = { path = "../ruff_python_ast" }
ruff_rustpython = { path = "../ruff_rustpython" }
ruff_text_size = { path = "../ruff_text_size" }
ruff_text_size = { workspace = true }
anyhow = { workspace = true }
clap = { workspace = true }

View file

@ -1,9 +1,5 @@
use rustpython_parser::ast::Location;
use ruff_python_ast::newlines::StrExt;
use ruff_python_ast::source_code::Locator;
use ruff_python_ast::types::Range;
use ruff_text_size::TextRange;
use ruff_text_size::{TextLen, TextRange, TextSize};
/// Return `true` if the given string is a radix literal (e.g., `0b101`).
pub fn is_radix_literal(content: &str) -> bool {
@ -17,26 +13,22 @@ pub fn is_radix_literal(content: &str) -> bool {
/// Find the first token in the given range that satisfies the given predicate.
pub fn find_tok(
location: Location,
end_location: Location,
range: TextRange,
locator: &Locator,
f: impl Fn(rustpython_parser::Tok) -> bool,
) -> (Location, Location) {
for (start, tok, end) in rustpython_parser::lexer::lex_located(
locator.slice(Range::new(location, end_location)),
) -> TextRange {
for (tok, tok_range) in rustpython_parser::lexer::lex_located(
&locator.contents()[range],
rustpython_parser::Mode::Module,
location,
range.start(),
)
.flatten()
{
if f(tok) {
return (start, end);
return tok_range;
}
}
unreachable!(
"Failed to find token in range {:?}..{:?}",
location, end_location
)
unreachable!("Failed to find token in range {:?}", range)
}
/// Expand the range of a compound statement.
@ -44,19 +36,17 @@ pub fn find_tok(
/// `location` is the start of the compound statement (e.g., the `if` in `if x:`).
/// `end_location` is the end of the last statement in the body.
pub fn expand_indented_block(
location: Location,
end_location: Location,
location: TextSize,
end_location: TextSize,
locator: &Locator,
) -> (Location, Location) {
) -> TextRange {
let contents = locator.contents();
let start_index = locator.offset(location);
let end_index = locator.offset(end_location);
// Find the colon, which indicates the end of the header.
let mut nesting = 0;
let mut colon = None;
for (start, tok, _end) in rustpython_parser::lexer::lex_located(
&contents[TextRange::new(start_index, end_index)],
for (tok, tok_range) in rustpython_parser::lexer::lex_located(
&contents[TextRange::new(location, end_location)],
rustpython_parser::Mode::Module,
location,
)
@ -64,7 +54,7 @@ pub fn expand_indented_block(
{
match tok {
rustpython_parser::Tok::Colon if nesting == 0 => {
colon = Some(start);
colon = Some(tok_range.start());
break;
}
rustpython_parser::Tok::Lpar
@ -77,55 +67,68 @@ pub fn expand_indented_block(
}
}
let colon_location = colon.unwrap();
let colon_index = locator.offset(colon_location);
// From here, we have two options: simple statement or compound statement.
let indent = rustpython_parser::lexer::lex_located(
&contents[TextRange::new(colon_index, end_index)],
&contents[TextRange::new(colon_location, end_location)],
rustpython_parser::Mode::Module,
colon_location,
)
.flatten()
.find_map(|(start, tok, _end)| match tok {
rustpython_parser::Tok::Indent => Some(start),
.find_map(|(tok, range)| match tok {
rustpython_parser::Tok::Indent => Some(range.end()),
_ => None,
});
let Some(indent_location) = indent else {
let line_end = locator.line_end(end_location);
let Some(indent_end) = indent else {
// Simple statement: from the colon to the end of the line.
return (colon_location, Location::new(end_location.row() + 1, 0));
return TextRange::new(colon_location, line_end);
};
let indent_width = indent_end - locator.line_start(indent_end);
// Compound statement: from the colon to the end of the block.
let mut offset = 0;
for (index, line) in contents[usize::from(end_index)..]
.universal_newlines()
.skip(1)
.enumerate()
{
if line.is_empty() {
continue;
// For each line that follows, check that there's no content up to the expected indent.
let mut offset = TextSize::default();
let mut line_offset = TextSize::default();
// Issue: the body extends too far — it includes the whole `try`, including the `except` clause
let rest = &contents[usize::from(line_end)..];
for (relative_offset, c) in rest.char_indices() {
if line_offset < indent_width && !c.is_whitespace() {
break; // Found end of block
}
if line
.chars()
.take(indent_location.column())
.all(char::is_whitespace)
{
offset = index + 1;
} else {
break;
match c {
'\n' | '\r' => {
// Ignore empty lines
if line_offset > TextSize::from(0) {
offset = TextSize::try_from(relative_offset).unwrap() + TextSize::from(1);
}
line_offset = TextSize::from(0);
}
_ => {
line_offset += c.text_len();
}
}
}
let end_location = Location::new(end_location.row() + 1 + offset, 0);
(colon_location, end_location)
// Reached end of file
let end = if line_offset >= indent_width {
contents.text_len()
} else {
line_end + offset
};
TextRange::new(colon_location, end)
}
/// Return true if the `orelse` block of an `if` statement is an `elif` statement.
pub fn is_elif(orelse: &[rustpython_parser::ast::Stmt], locator: &Locator) -> bool {
if orelse.len() == 1 && matches!(orelse[0].node, rustpython_parser::ast::StmtKind::If { .. }) {
let contents = locator.after(orelse[0].location);
let contents = locator.after(orelse[0].start());
if contents.starts_with("elif") {
return true;
}

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,5 @@
use ruff_formatter::prelude::*;
use ruff_formatter::{write, Format};
use ruff_python_ast::types::Range;
use ruff_text_size::{TextRange, TextSize};
use crate::context::ASTFormatContext;
@ -68,25 +67,22 @@ pub fn statements(suite: &[Stmt]) -> Statements {
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Literal {
range: Range,
range: TextRange,
}
impl Format<ASTFormatContext<'_>> for Literal {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
let text = f.context().contents();
let locator = f.context().locator();
let start_index = locator.offset(self.range.location);
let end_index = locator.offset(self.range.end_location);
f.write_element(FormatElement::StaticTextSlice {
text,
range: TextRange::new(start_index, end_index),
range: self.range,
})
}
}
#[inline]
pub const fn literal(range: Range) -> Literal {
pub const fn literal(range: TextRange) -> Literal {
Literal { range }
}

View file

@ -4,7 +4,6 @@ use rustpython_parser::ast::Constant;
use ruff_formatter::prelude::*;
use ruff_formatter::{format_args, write};
use ruff_python_ast::types::Range;
use ruff_text_size::TextSize;
use crate::context::ASTFormatContext;
@ -39,7 +38,7 @@ fn format_name(
expr: &Expr,
_id: &str,
) -> FormatResult<()> {
write!(f, [literal(Range::from(expr))])?;
write!(f, [literal(expr.range())])?;
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}
@ -97,62 +96,68 @@ fn format_tuple(
} else if !elts.is_empty() {
write!(
f,
[group(&format_with(|f| {
if expr.parentheses.is_if_expanded() {
write!(f, [if_group_breaks(&text("("))])?;
}
if matches!(
expr.parentheses,
Parenthesize::IfExpanded | Parenthesize::Always
) {
write!(
f,
[soft_block_indent(&format_with(|f| {
let magic_trailing_comma =
expr.trivia.iter().any(|c| c.kind.is_magic_trailing_comma());
let is_unbroken =
expr.location.row() == expr.end_location.unwrap().row();
if magic_trailing_comma {
write!(f, [expand_parent()])?;
}
for (i, elt) in elts.iter().enumerate() {
write!(f, [elt.format()])?;
if i < elts.len() - 1 {
write!(f, [text(",")])?;
write!(f, [soft_line_break_or_space()])?;
} else {
if magic_trailing_comma || is_unbroken {
write!(f, [if_group_breaks(&text(","))])?;
}
}
}
Ok(())
}))]
)?;
} else {
let magic_trailing_comma =
expr.trivia.iter().any(|c| c.kind.is_magic_trailing_comma());
let is_unbroken = expr.location.row() == expr.end_location.unwrap().row();
if magic_trailing_comma {
write!(f, [expand_parent()])?;
[group(&format_with(
|f: &mut Formatter<ASTFormatContext<'_>>| {
if expr.parentheses.is_if_expanded() {
write!(f, [if_group_breaks(&text("("))])?;
}
for (i, elt) in elts.iter().enumerate() {
write!(f, [elt.format()])?;
if i < elts.len() - 1 {
write!(f, [text(",")])?;
write!(f, [soft_line_break_or_space()])?;
} else {
if magic_trailing_comma || is_unbroken {
write!(f, [if_group_breaks(&text(","))])?;
if matches!(
expr.parentheses,
Parenthesize::IfExpanded | Parenthesize::Always
) {
write!(
f,
[soft_block_indent(&format_with(
|f: &mut Formatter<ASTFormatContext<'_>>| {
let magic_trailing_comma = expr
.trivia
.iter()
.any(|c| c.kind.is_magic_trailing_comma());
let is_unbroken =
!f.context().locator().contains_line_break(expr.range());
if magic_trailing_comma {
write!(f, [expand_parent()])?;
}
for (i, elt) in elts.iter().enumerate() {
write!(f, [elt.format()])?;
if i < elts.len() - 1 {
write!(f, [text(",")])?;
write!(f, [soft_line_break_or_space()])?;
} else {
if magic_trailing_comma || is_unbroken {
write!(f, [if_group_breaks(&text(","))])?;
}
}
}
Ok(())
}
))]
)?;
} else {
let magic_trailing_comma =
expr.trivia.iter().any(|c| c.kind.is_magic_trailing_comma());
let is_unbroken = !f.context().locator().contains_line_break(expr.range());
if magic_trailing_comma {
write!(f, [expand_parent()])?;
}
for (i, elt) in elts.iter().enumerate() {
write!(f, [elt.format()])?;
if i < elts.len() - 1 {
write!(f, [text(",")])?;
write!(f, [soft_line_break_or_space()])?;
} else {
if magic_trailing_comma || is_unbroken {
write!(f, [if_group_breaks(&text(","))])?;
}
}
}
}
if expr.parentheses.is_if_expanded() {
write!(f, [if_group_breaks(&text(")"))])?;
}
Ok(())
}
if expr.parentheses.is_if_expanded() {
write!(f, [if_group_breaks(&text(")"))])?;
}
Ok(())
}))]
))]
)?;
}
Ok(())
@ -577,7 +582,7 @@ fn format_joined_str(
expr: &Expr,
_values: &[Expr],
) -> FormatResult<()> {
write!(f, [literal(Range::from(expr))])?;
write!(f, [literal(expr.range())])?;
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}
@ -598,11 +603,11 @@ fn format_constant(
write!(f, [text("False")])?;
}
}
Constant::Int(_) => write!(f, [int_literal(Range::from(expr))])?,
Constant::Float(_) => write!(f, [float_literal(Range::from(expr))])?,
Constant::Int(_) => write!(f, [int_literal(expr.range())])?,
Constant::Float(_) => write!(f, [float_literal(expr.range())])?,
Constant::Str(_) => write!(f, [string_literal(expr)])?,
Constant::Bytes(_) => write!(f, [string_literal(expr)])?,
Constant::Complex { .. } => write!(f, [complex_literal(Range::from(expr))])?,
Constant::Complex { .. } => write!(f, [complex_literal(expr.range())])?,
Constant::Tuple(_) => unreachable!("Constant::Tuple should be handled by format_tuple"),
}
write!(f, [end_of_line_comments(expr)])?;

View file

@ -1,8 +1,7 @@
use rustpython_parser::ast::Location;
use std::ops::{Add, Sub};
use ruff_formatter::prelude::*;
use ruff_formatter::{write, Format};
use ruff_python_ast::types::Range;
use ruff_text_size::{TextRange, TextSize};
use crate::context::ASTFormatContext;
@ -10,17 +9,14 @@ use crate::format::builders::literal;
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct FloatAtom {
range: Range,
range: TextRange,
}
impl Format<ASTFormatContext<'_>> for FloatAtom {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
let locator = f.context().locator();
let contents = f.context().contents();
let start_index = locator.offset(self.range.location);
let end_index = locator.offset(self.range.end_location);
let content = &contents[TextRange::new(start_index, end_index)];
let content = &contents[self.range];
if let Some(dot_index) = content.find('.') {
let integer = &content[..dot_index];
let fractional = &content[dot_index + 1..];
@ -30,12 +26,11 @@ impl Format<ASTFormatContext<'_>> for FloatAtom {
} else {
write!(
f,
[literal(Range::new(
self.range.location,
Location::new(
self.range.location.row(),
self.range.location.column() + dot_index
),
[literal(TextRange::new(
self.range.start(),
self.range
.start()
.add(TextSize::try_from(dot_index).unwrap())
))]
)?;
}
@ -47,12 +42,11 @@ impl Format<ASTFormatContext<'_>> for FloatAtom {
} else {
write!(
f,
[literal(Range::new(
Location::new(
self.range.location.row(),
self.range.location.column() + dot_index + 1
),
self.range.end_location
[literal(TextRange::new(
self.range
.start()
.add(TextSize::try_from(dot_index + 1).unwrap()),
self.range.end()
))]
)?;
}
@ -65,35 +59,31 @@ impl Format<ASTFormatContext<'_>> for FloatAtom {
}
#[inline]
const fn float_atom(range: Range) -> FloatAtom {
const fn float_atom(range: TextRange) -> FloatAtom {
FloatAtom { range }
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct FloatLiteral {
range: Range,
range: TextRange,
}
impl Format<ASTFormatContext<'_>> for FloatLiteral {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
let locator = f.context().locator();
let contents = f.context().contents();
let start_index = locator.offset(self.range.location);
let end_index = locator.offset(self.range.end_location);
let content = &contents[TextRange::new(start_index, end_index)];
let content = &contents[self.range];
// Scientific notation
if let Some(exponent_index) = content.find('e').or_else(|| content.find('E')) {
// Write the base.
write!(
f,
[float_atom(Range::new(
self.range.location,
Location::new(
self.range.location.row(),
self.range.location.column() + exponent_index
),
[float_atom(TextRange::new(
self.range.start(),
self.range
.start()
.add(TextSize::try_from(exponent_index).unwrap())
))]
)?;
@ -103,12 +93,11 @@ impl Format<ASTFormatContext<'_>> for FloatLiteral {
let plus = content[exponent_index + 1..].starts_with('+');
write!(
f,
[literal(Range::new(
Location::new(
self.range.location.row(),
self.range.location.column() + exponent_index + 1 + usize::from(plus)
),
self.range.end_location
[literal(TextRange::new(
self.range
.start()
.add(TextSize::try_from(exponent_index + 1 + usize::from(plus)).unwrap()),
self.range.end()
))]
)?;
} else {
@ -120,24 +109,21 @@ impl Format<ASTFormatContext<'_>> for FloatLiteral {
}
#[inline]
pub const fn float_literal(range: Range) -> FloatLiteral {
pub const fn float_literal(range: TextRange) -> FloatLiteral {
FloatLiteral { range }
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct IntLiteral {
range: Range,
range: TextRange,
}
impl Format<ASTFormatContext<'_>> for IntLiteral {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
let locator = f.context().locator();
let contents = f.context().contents();
let start_index = locator.offset(self.range.location);
let end_index = locator.offset(self.range.end_location);
for prefix in ["0b", "0B", "0o", "0O", "0x", "0X"] {
let content = &contents[TextRange::new(start_index, end_index)];
let content = &contents[self.range];
if content.starts_with(prefix) {
// In each case, the prefix must be lowercase, while the suffix must be uppercase.
let prefix = &content[..prefix.len()];
@ -170,35 +156,28 @@ impl Format<ASTFormatContext<'_>> for IntLiteral {
}
#[inline]
pub const fn int_literal(range: Range) -> IntLiteral {
pub const fn int_literal(range: TextRange) -> IntLiteral {
IntLiteral { range }
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ComplexLiteral {
range: Range,
range: TextRange,
}
impl Format<ASTFormatContext<'_>> for ComplexLiteral {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
let locator = f.context().locator();
let contents = f.context().contents();
let start_index = locator.offset(self.range.location);
let end_index = locator.offset(self.range.end_location);
let content = &contents[TextRange::new(start_index, end_index)];
let content = &contents[self.range];
if content.ends_with('j') {
write!(f, [literal(self.range)])?;
} else if content.ends_with('J') {
write!(
f,
[literal(Range::new(
self.range.location,
Location::new(
self.range.end_location.row(),
self.range.end_location.column() - 1
),
[literal(TextRange::new(
self.range.start(),
self.range.end().sub(TextSize::from(1))
))]
)?;
write!(f, [text("j")])?;
@ -211,6 +190,6 @@ impl Format<ASTFormatContext<'_>> for ComplexLiteral {
}
#[inline]
pub const fn complex_literal(range: Range) -> ComplexLiteral {
pub const fn complex_literal(range: TextRange) -> ComplexLiteral {
ComplexLiteral { range }
}

View file

@ -3,7 +3,6 @@ use rustpython_parser::{Mode, Tok};
use ruff_formatter::prelude::*;
use ruff_formatter::{write, Format};
use ruff_python_ast::str::{leading_quote, trailing_quote};
use ruff_python_ast::types::Range;
use ruff_text_size::{TextRange, TextSize};
use crate::context::ASTFormatContext;
@ -11,18 +10,15 @@ use crate::cst::Expr;
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct StringLiteralPart {
range: Range,
range: TextRange,
}
impl Format<ASTFormatContext<'_>> for StringLiteralPart {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
let locator = f.context().locator();
let contents = f.context().contents();
let start_index = locator.offset(self.range.location);
let end_index = locator.offset(self.range.end_location);
// Extract leading and trailing quotes.
let contents = &contents[TextRange::new(start_index, end_index)];
let contents = &contents[self.range];
let leading_quote = leading_quote(contents).unwrap();
let trailing_quote = trailing_quote(contents).unwrap();
let body = &contents[leading_quote.len()..contents.len() - trailing_quote.len()];
@ -114,7 +110,7 @@ impl Format<ASTFormatContext<'_>> for StringLiteralPart {
}
#[inline]
pub const fn string_literal_part(range: Range) -> StringLiteralPart {
pub const fn string_literal_part(range: TextRange) -> StringLiteralPart {
StringLiteralPart { range }
}
@ -129,12 +125,12 @@ impl Format<ASTFormatContext<'_>> for StringLiteral<'_> {
// TODO(charlie): This tokenization needs to happen earlier, so that we can attach
// comments to individual string literals.
let contents = f.context().locator().slice(expr);
let elts = rustpython_parser::lexer::lex_located(contents, Mode::Module, expr.location)
let contents = f.context().locator().slice(expr.range());
let elts = rustpython_parser::lexer::lex_located(contents, Mode::Module, expr.start())
.flatten()
.filter_map(|(start, tok, end)| {
.filter_map(|(tok, range)| {
if matches!(tok, Tok::String { .. }) {
Some(Range::new(start, end))
Some(range)
} else {
None
}

View file

@ -28,7 +28,7 @@ pub fn fmt(contents: &str) -> Result<Formatted<ASTFormatContext>> {
let tokens: Vec<LexResult> = ruff_rustpython::tokenize(contents);
// Extract trivia.
let trivia = trivia::extract_trivia_tokens(&tokens);
let trivia = trivia::extract_trivia_tokens(&tokens, contents);
// Parse the AST.
let python_ast = ruff_rustpython::parse_program_tokens(tokens, "<filename>")?;

View file

@ -155,7 +155,7 @@ impl<'a> Visitor<'a> for ParenthesesNormalizer<'_> {
},
) {
// TODO(charlie): Encode this in the AST via separate node types.
if !is_radix_literal(self.locator.slice(&**value)) {
if !is_radix_literal(self.locator.slice(value.range())) {
value.parentheses = Parenthesize::Always;
}
}

View file

@ -1,16 +1,15 @@
use ruff_text_size::{TextRange, TextSize};
use rustc_hash::FxHashMap;
use rustpython_parser::ast::Location;
use rustpython_parser::lexer::LexResult;
use rustpython_parser::Tok;
use ruff_python_ast::types::Range;
use std::ops::Add;
use crate::cst::{
Alias, Arg, Body, BoolOp, CmpOp, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Keyword,
Operator, Pattern, PatternKind, SliceIndex, SliceIndexKind, Stmt, StmtKind, UnaryOp,
};
#[derive(Clone, Debug)]
#[derive(Clone, Copy, Debug)]
pub enum Node<'a> {
Alias(&'a Alias),
Arg(&'a Arg),
@ -48,41 +47,41 @@ impl Node<'_> {
}
}
pub fn location(&self) -> Location {
pub fn start(&self) -> TextSize {
match self {
Node::Alias(node) => node.location,
Node::Arg(node) => node.location,
Node::Body(node) => node.location,
Node::BoolOp(node) => node.location,
Node::CmpOp(node) => node.location,
Node::Excepthandler(node) => node.location,
Node::Expr(node) => node.location,
Node::Keyword(node) => node.location,
Node::Alias(node) => node.start(),
Node::Arg(node) => node.start(),
Node::Body(node) => node.start(),
Node::BoolOp(node) => node.start(),
Node::CmpOp(node) => node.start(),
Node::Excepthandler(node) => node.start(),
Node::Expr(node) => node.start(),
Node::Keyword(node) => node.start(),
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
Node::Operator(node) => node.location,
Node::Pattern(node) => node.location,
Node::SliceIndex(node) => node.location,
Node::Stmt(node) => node.location,
Node::UnaryOp(node) => node.location,
Node::Operator(node) => node.start(),
Node::Pattern(node) => node.start(),
Node::SliceIndex(node) => node.start(),
Node::Stmt(node) => node.start(),
Node::UnaryOp(node) => node.start(),
}
}
pub fn end_location(&self) -> Location {
pub fn end(&self) -> TextSize {
match self {
Node::Alias(node) => node.end_location.unwrap(),
Node::Arg(node) => node.end_location.unwrap(),
Node::Body(node) => node.end_location.unwrap(),
Node::BoolOp(node) => node.end_location.unwrap(),
Node::CmpOp(node) => node.end_location.unwrap(),
Node::Excepthandler(node) => node.end_location.unwrap(),
Node::Expr(node) => node.end_location.unwrap(),
Node::Keyword(node) => node.end_location.unwrap(),
Node::Alias(node) => node.end(),
Node::Arg(node) => node.end(),
Node::Body(node) => node.end(),
Node::BoolOp(node) => node.end(),
Node::CmpOp(node) => node.end(),
Node::Excepthandler(node) => node.end(),
Node::Expr(node) => node.end(),
Node::Keyword(node) => node.end(),
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
Node::Operator(node) => node.end_location.unwrap(),
Node::Pattern(node) => node.end_location.unwrap(),
Node::SliceIndex(node) => node.end_location.unwrap(),
Node::Stmt(node) => node.end_location.unwrap(),
Node::UnaryOp(node) => node.end_location.unwrap(),
Node::Operator(node) => node.end(),
Node::Pattern(node) => node.end(),
Node::SliceIndex(node) => node.end(),
Node::Stmt(node) => node.end(),
Node::UnaryOp(node) => node.end(),
}
}
}
@ -98,11 +97,20 @@ pub enum TriviaTokenKind {
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TriviaToken {
pub start: Location,
pub end: Location,
pub range: TextRange,
pub kind: TriviaTokenKind,
}
impl TriviaToken {
pub const fn start(&self) -> TextSize {
self.range.start()
}
pub const fn end(&self) -> TextSize {
self.range.end()
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, is_macro::Is)]
pub enum TriviaKind {
/// A Comment that is separated by at least one line break from the
@ -115,7 +123,7 @@ pub enum TriviaKind {
/// # This is an own-line comment.
/// b = 2
/// ```
OwnLineComment(Range),
OwnLineComment(TextRange),
/// A comment that is on the same line as the preceding token.
///
/// # Examples
@ -126,7 +134,7 @@ pub enum TriviaKind {
/// a = 1 # This is an end-of-line comment.
/// b = 2
/// ```
EndOfLineComment(Range),
EndOfLineComment(TextRange),
MagicTrailingComma,
EmptyLine,
Parentheses,
@ -167,11 +175,11 @@ impl Trivia {
relationship,
},
TriviaTokenKind::OwnLineComment => Self {
kind: TriviaKind::OwnLineComment(Range::new(token.start, token.end)),
kind: TriviaKind::OwnLineComment(token.range),
relationship,
},
TriviaTokenKind::EndOfLineComment => Self {
kind: TriviaKind::EndOfLineComment(Range::new(token.start, token.end)),
kind: TriviaKind::EndOfLineComment(token.range),
relationship,
},
TriviaTokenKind::Parentheses => Self {
@ -182,30 +190,53 @@ impl Trivia {
}
}
pub fn extract_trivia_tokens(lxr: &[LexResult]) -> Vec<TriviaToken> {
pub fn extract_trivia_tokens(lxr: &[LexResult], text: &str) -> Vec<TriviaToken> {
let mut tokens = vec![];
let mut prev_tok: Option<(&Location, &Tok, &Location)> = None;
let mut prev_non_newline_tok: Option<(&Location, &Tok, &Location)> = None;
let mut prev_semantic_tok: Option<(&Location, &Tok, &Location)> = None;
let mut prev_end = TextSize::default();
let mut prev_tok: Option<(&Tok, TextRange)> = None;
let mut prev_semantic_tok: Option<(&Tok, TextRange)> = None;
let mut parens = vec![];
for (start, tok, end) in lxr.iter().flatten() {
for (tok, range) in lxr.iter().flatten() {
// Add empty lines.
if let Some((.., prev)) = prev_non_newline_tok {
for row in prev.row() + 1..start.row() {
let trivia = &text[TextRange::new(prev_end, range.start())];
let bytes = trivia.as_bytes();
let mut bytes_iter = bytes.iter().enumerate();
let mut after_new_line =
matches!(prev_tok, Some((Tok::Newline | Tok::NonLogicalNewline, _)));
while let Some((index, byte)) = bytes_iter.next() {
let len = match byte {
b'\r' if bytes.get(index + 1) == Some(&b'\n') => {
bytes_iter.next();
TextSize::from(2)
}
b'\n' | b'\r' => TextSize::from(1),
_ => {
// Must be whitespace or the parser would generate a token
continue;
}
};
if after_new_line {
let new_line_start = prev_end.add(TextSize::try_from(index).unwrap());
tokens.push(TriviaToken {
start: Location::new(row, 0),
end: Location::new(row + 1, 0),
range: TextRange::new(new_line_start, new_line_start.add(len)),
kind: TriviaTokenKind::EmptyLine,
});
} else {
after_new_line = true;
}
}
// Add comments.
if let Tok::Comment(_) = tok {
tokens.push(TriviaToken {
start: *start,
end: *end,
kind: if prev_non_newline_tok.map_or(true, |(prev, ..)| prev.row() < start.row()) {
range: *range,
// Previously this checked `prev_non_newline_tok` instead
kind: if after_new_line || prev_tok.is_none() {
TriviaTokenKind::OwnLineComment
} else {
TriviaTokenKind::EndOfLineComment
@ -218,11 +249,10 @@ pub fn extract_trivia_tokens(lxr: &[LexResult]) -> Vec<TriviaToken> {
tok,
Tok::Rpar | Tok::Rsqb | Tok::Rbrace | Tok::Equal | Tok::Newline
) {
if let Some((prev_start, prev_tok, prev_end)) = prev_semantic_tok {
if let Some((prev_tok, prev_range)) = prev_semantic_tok {
if prev_tok == &Tok::Comma {
tokens.push(TriviaToken {
start: *prev_start,
end: *prev_end,
range: prev_range,
kind: TriviaTokenKind::MagicTrailingComma,
});
}
@ -230,7 +260,7 @@ pub fn extract_trivia_tokens(lxr: &[LexResult]) -> Vec<TriviaToken> {
}
if matches!(tok, Tok::Lpar) {
if prev_tok.map_or(true, |(_, prev_tok, _)| {
if prev_tok.map_or(true, |(prev_tok, _)| {
!matches!(
prev_tok,
Tok::Name { .. }
@ -240,40 +270,36 @@ pub fn extract_trivia_tokens(lxr: &[LexResult]) -> Vec<TriviaToken> {
| Tok::String { .. }
)
}) {
parens.push((start, true));
parens.push((range.start(), true));
} else {
parens.push((start, false));
parens.push((range.start(), false));
}
} else if matches!(tok, Tok::Rpar) {
let (start, explicit) = parens.pop().unwrap();
if explicit {
tokens.push(TriviaToken {
start: *start,
end: *end,
range: TextRange::new(start, range.end()),
kind: TriviaTokenKind::Parentheses,
});
}
}
prev_tok = Some((start, tok, end));
// Track the most recent non-whitespace token.
if !matches!(tok, Tok::Newline | Tok::NonLogicalNewline) {
prev_non_newline_tok = Some((start, tok, end));
}
prev_tok = Some((tok, *range));
// Track the most recent semantic token.
if !matches!(
tok,
Tok::Newline | Tok::NonLogicalNewline | Tok::Comment(..)
) {
prev_semantic_tok = Some((start, tok, end));
prev_semantic_tok = Some((tok, *range));
}
prev_end = range.end();
}
tokens
}
fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
fn sorted_child_nodes_inner<'a>(node: Node<'a>, result: &mut Vec<Node<'a>>) {
match node {
Node::Mod(nodes) => {
for stmt in nodes.iter() {
@ -281,9 +307,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
}
}
Node::Body(body) => {
for stmt in &body.node {
result.push(Node::Stmt(stmt));
}
result.extend(body.iter().map(Node::Stmt));
}
Node::Stmt(stmt) => match &stmt.node {
StmtKind::Return { value } => {
@ -734,15 +758,16 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
}
}
pub fn sorted_child_nodes<'a>(node: &Node<'a>) -> Vec<Node<'a>> {
pub fn sorted_child_nodes(node: Node) -> Vec<Node> {
let mut result = Vec::new();
sorted_child_nodes_inner(node, &mut result);
result
}
pub fn decorate_token<'a>(
token: &TriviaToken,
node: &Node<'a>,
node: Node<'a>,
enclosing_node: Option<Node<'a>>,
enclosed_node: Option<Node<'a>>,
cache: &mut FxHashMap<usize, Vec<Node<'a>>>,
@ -765,51 +790,45 @@ pub fn decorate_token<'a>(
while left < right {
let middle = (left + right) / 2;
let child = &child_nodes[middle];
let start = child.location();
let end = child.end_location();
let child = child_nodes[middle];
let start = child.start();
let end = child.end();
if let Some(existing) = &enclosed_node {
// Special-case: if we're dealing with a statement that's a single expression,
// we want to treat the expression as the enclosed node.
let existing_start = existing.location();
let existing_end = existing.end_location();
let existing_start = existing.start();
let existing_end = existing.end();
if start == existing_start && end == existing_end {
enclosed_node = Some(child.clone());
enclosed_node = Some(child);
}
} else {
if token.start <= start && token.end >= end {
enclosed_node = Some(child.clone());
if token.start() <= start && token.end() >= end {
enclosed_node = Some(child);
}
}
// The comment is completely contained by this child node.
if token.start >= start && token.end <= end {
return decorate_token(
token,
&child.clone(),
Some(child.clone()),
enclosed_node,
cache,
);
if token.start() >= start && token.end() <= end {
return decorate_token(token, child, Some(child), enclosed_node, cache);
}
if end <= token.start {
if end <= token.start() {
// This child node falls completely before the comment.
// Because we will never consider this node or any nodes
// before it again, this node must be the closest preceding
// node we have encountered so far.
preceding_node = Some(child.clone());
preceding_node = Some(child);
left = middle + 1;
continue;
}
if token.end <= start {
if token.end() <= start {
// This child node falls completely after the comment.
// Because we will never consider this node or any nodes after
// it again, this node must be the closest following node we
// have encountered so far.
following_node = Some(child.clone());
following_node = Some(child);
right = middle;
continue;
}
@ -944,7 +963,7 @@ pub fn decorate_trivia(tokens: Vec<TriviaToken>, python_ast: &[Stmt]) -> TriviaI
let mut cache = FxHashMap::default();
for token in &tokens {
let (preceding_node, following_node, enclosing_node, enclosed_node) =
decorate_token(token, &Node::Mod(python_ast), None, None, &mut cache);
decorate_token(token, Node::Mod(python_ast), None, None, &mut cache);
stack.push((
preceding_node,