Idea: add const generic peek_tokens method to parser (#1255)

Author: Joey Hain, 2024-05-06 17:33:10 -07:00 (committed by GitHub)
parent 138722a7c9
commit a12a8882e7

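For context, here is a minimal usage sketch of the const generic API this commit introduces, assuming the `peek_tokens` / `peek_tokens_with_location` signatures shown in the diff below; the surrounding `main` function and the SQL string are illustrative only and not part of the commit. The array length `N` is inferred from a destructuring pattern, or can be given explicitly with a turbofish; positions past the last token are filled with `Token::EOF`.

use sqlparser::dialect::GenericDialect;
use sqlparser::keywords::Keyword;
use sqlparser::parser::Parser;
use sqlparser::tokenizer::{Token, Word};

fn main() {
    let dialect = GenericDialect {};
    let parser = Parser::new(&dialect)
        .try_with_sql("ORDER BY foo, bar")
        .unwrap();

    // N = 2 is inferred from the destructuring pattern.
    let [first, second] = parser.peek_tokens();
    assert!(matches!(first, Token::Word(Word { keyword: Keyword::ORDER, .. })));
    assert!(matches!(second, Token::Word(Word { keyword: Keyword::BY, .. })));

    // N can also be spelled out; peeking past the last token yields Token::EOF.
    let tokens = parser.peek_tokens::<8>();
    assert!(matches!(tokens[7], Token::EOF));
}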

@@ -2771,9 +2771,7 @@ impl<'a> Parser<'a> {
let token = self.peek_token();
debug!("get_next_precedence() {:?}", token);
let token_0 = self.peek_nth_token(0);
let token_1 = self.peek_nth_token(1);
let token_2 = self.peek_nth_token(2);
let [token_0, token_1, token_2] = self.peek_tokens_with_location();
debug!("0: {token_0} 1: {token_1} 2: {token_2}");
match token.token {
Token::Word(w) if w.keyword == Keyword::OR => Ok(Self::OR_PREC),
@@ -2865,6 +2863,56 @@ impl<'a> Parser<'a> {
self.peek_nth_token(0)
}
/// Returns the `N` next non-whitespace tokens that have not yet been
/// processed.
///
/// Example:
/// ```rust
/// # use sqlparser::dialect::GenericDialect;
/// # use sqlparser::parser::Parser;
/// # use sqlparser::keywords::Keyword;
/// # use sqlparser::tokenizer::{Token, Word};
/// let dialect = GenericDialect {};
/// let mut parser = Parser::new(&dialect).try_with_sql("ORDER BY foo, bar").unwrap();
///
/// // Note that Rust infers the number of tokens to peek based on the
/// // length of the slice pattern!
/// assert!(matches!(
/// parser.peek_tokens(),
/// [
/// Token::Word(Word { keyword: Keyword::ORDER, .. }),
/// Token::Word(Word { keyword: Keyword::BY, .. }),
/// ]
/// ));
/// ```
pub fn peek_tokens<const N: usize>(&self) -> [Token; N] {
self.peek_tokens_with_location()
.map(|with_loc| with_loc.token)
}
/// Returns the `N` next non-whitespace tokens with locations that have not
/// yet been processed.
///
/// See [`Self::peek_tokens`] for an example.
pub fn peek_tokens_with_location<const N: usize>(&self) -> [TokenWithLocation; N] {
let mut index = self.index;
core::array::from_fn(|_| loop {
let token = self.tokens.get(index);
index += 1;
if let Some(TokenWithLocation {
token: Token::Whitespace(_),
location: _,
}) = token
{
continue;
}
break token.cloned().unwrap_or(TokenWithLocation {
token: Token::EOF,
location: Location { line: 0, column: 0 },
});
})
}
/// Return nth non-whitespace token that has not yet been processed
pub fn peek_nth_token(&self, mut n: usize) -> TokenWithLocation {
let mut index = self.index;
@@ -3159,8 +3207,7 @@ impl<'a> Parser<'a> {
}
// (,)
if self.options.trailing_commas
&& matches!(self.peek_nth_token(0).token, Token::Comma)
&& matches!(self.peek_nth_token(1).token, Token::RParen)
&& matches!(self.peek_tokens(), [Token::Comma, Token::RParen])
{
let _ = self.consume_token(&Token::Comma);
return Ok(vec![]);
@@ -10479,6 +10526,51 @@ mod tests {
});
}
#[test]
fn test_peek_tokens() {
all_dialects().run_parser_method("SELECT foo AS bar FROM baz", |parser| {
assert!(matches!(
parser.peek_tokens(),
[Token::Word(Word {
keyword: Keyword::SELECT,
..
})]
));
assert!(matches!(
parser.peek_tokens(),
[
Token::Word(Word {
keyword: Keyword::SELECT,
..
}),
Token::Word(_),
Token::Word(Word {
keyword: Keyword::AS,
..
}),
]
));
for _ in 0..4 {
parser.next_token();
}
assert!(matches!(
parser.peek_tokens(),
[
Token::Word(Word {
keyword: Keyword::FROM,
..
}),
Token::Word(_),
Token::EOF,
Token::EOF,
]
))
})
}
#[cfg(test)]
mod test_parse_data_type {
use crate::ast::{