mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-09-04 05:00:34 +00:00
Add configurable recursion limit to parser, to protect against stack overflows (#764)
This commit is contained in:
parent
2c20ec0be5
commit
79d0baad73
4 changed files with 302 additions and 31 deletions
231
src/parser.rs
231
src/parser.rs
|
@ -37,6 +37,7 @@ use crate::tokenizer::*;
|
|||
pub enum ParserError {
|
||||
/// An error carrying a message string; the `Display` impl prints the
/// message verbatim. NOTE(review): presumably produced when the
/// tokenizer fails (e.g. via `tokenizer.tokenize()?`) -- confirm against
/// the `From` impl, which is not visible here.
TokenizerError(String),
|
||||
/// A parse error carrying a message string; the `Display` impl prints the
/// message verbatim.
ParserError(String),
|
||||
/// Returned by `RecursionCounter::try_decrease` when no recursion depth
/// remains; displayed as "recursion limit exceeded".
RecursionLimitExceeded,
|
||||
}
|
||||
|
||||
// Use `Parser::expected` instead, if possible
|
||||
|
@ -55,6 +56,92 @@ macro_rules! return_ok_if_some {
|
|||
}};
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
/// Implementation of [`RecursionCounter`] used when `std` is available
|
||||
mod recursion {
|
||||
use core::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::rc::Rc;
|
||||
|
||||
use super::ParserError;
|
||||
|
||||
/// Tracks remaining recursion depth. This value is decremented on
|
||||
/// each call to `try_decrease()`, when it reaches 0 an error will
|
||||
/// be returned.
|
||||
///
|
||||
/// Note: Uses an Rc and AtomicUsize in order to satisfy the Rust
|
||||
/// borrow checker so the automatic DepthGuard decrement a
|
||||
/// reference to the counter. The actual value is not modified
|
||||
/// concurrently
|
||||
pub(crate) struct RecursionCounter {
|
||||
remaining_depth: Rc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl RecursionCounter {
|
||||
/// Creates a [`RecursionCounter`] with the specified maximum
|
||||
/// depth
|
||||
pub fn new(remaining_depth: usize) -> Self {
|
||||
Self {
|
||||
remaining_depth: Rc::new(remaining_depth.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Decreases the remaining depth by 1.
|
||||
///
|
||||
/// Returns `Err` if the remaining depth falls to 0.
|
||||
///
|
||||
/// Returns a [`DepthGuard`] which will adds 1 to the
|
||||
/// remaining depth upon drop;
|
||||
pub fn try_decrease(&self) -> Result<DepthGuard, ParserError> {
|
||||
let old_value = self.remaining_depth.fetch_sub(1, Ordering::SeqCst);
|
||||
// ran out of space
|
||||
if old_value == 0 {
|
||||
Err(ParserError::RecursionLimitExceeded)
|
||||
} else {
|
||||
Ok(DepthGuard::new(Rc::clone(&self.remaining_depth)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Guard that increass the remaining depth by 1 on drop
|
||||
pub struct DepthGuard {
|
||||
remaining_depth: Rc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl DepthGuard {
|
||||
fn new(remaining_depth: Rc<AtomicUsize>) -> Self {
|
||||
Self { remaining_depth }
|
||||
}
|
||||
}
|
||||
impl Drop for DepthGuard {
|
||||
fn drop(&mut self) {
|
||||
self.remaining_depth.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
mod recursion {
|
||||
/// Implementation of [`RecursionCounter`] if std is NOT available (and does not
|
||||
/// guard against stack overflow).
|
||||
///
|
||||
/// Has the same API as the std RecursionCounter implementation
|
||||
/// but does not actually limit stack depth.
|
||||
pub(crate) struct RecursionCounter {}
|
||||
|
||||
impl RecursionCounter {
|
||||
/// Creates a counter; the requested depth is ignored in this build.
pub fn new(_remaining_depth: usize) -> Self {
|
||||
Self {}
|
||||
}
|
||||
/// Always succeeds: no depth is tracked, so the limit is never reached.
pub fn try_decrease(&self) -> Result<DepthGuard, super::ParserError> {
|
||||
Ok(DepthGuard {})
|
||||
}
|
||||
}
|
||||
|
||||
/// Placeholder guard with no state and no `Drop` behavior.
pub struct DepthGuard {}
|
||||
}
|
||||
|
||||
use recursion::RecursionCounter;
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum IsOptional {
|
||||
Optional,
|
||||
|
@ -96,6 +183,7 @@ impl fmt::Display for ParserError {
|
|||
match self {
|
||||
ParserError::TokenizerError(s) => s,
|
||||
ParserError::ParserError(s) => s,
|
||||
ParserError::RecursionLimitExceeded => "recursion limit exceeded",
|
||||
}
|
||||
)
|
||||
}
|
||||
|
@ -104,22 +192,78 @@ impl fmt::Display for ParserError {
|
|||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for ParserError {}
|
||||
|
||||
// Default recursion limit: the parser allows nesting up to this depth before returning an error
|
||||
const DEFAULT_REMAINING_DEPTH: usize = 50;
|
||||
|
||||
/// SQL parser state: a token stream, a cursor into it, the dialect in use
/// and a recursion-depth budget.
pub struct Parser<'a> {
|
||||
/// The tokens (with source locations) being parsed
tokens: Vec<TokenWithLocation>,
|
||||
/// The index of the first unprocessed token in `self.tokens`
|
||||
index: usize,
|
||||
/// The current dialect to use
|
||||
dialect: &'a dyn Dialect,
|
||||
/// Ensures the stack does not overflow by limiting recursion depth
|
||||
recursion_counter: RecursionCounter,
|
||||
}
|
||||
|
||||
impl<'a> Parser<'a> {
|
||||
/// Parse the specified tokens
|
||||
/// To avoid breaking backwards compatibility, this function accepts
|
||||
/// bare tokens.
|
||||
pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
|
||||
Parser::new_without_locations(tokens, dialect)
|
||||
/// Create a parser for a [`Dialect`]
|
||||
///
|
||||
/// See also [`Parser::parse_sql`]
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
|
||||
/// # fn main() -> Result<(), ParserError> {
|
||||
/// let dialect = GenericDialect{};
|
||||
/// let statements = Parser::new(&dialect)
|
||||
/// .try_with_sql("SELECT * FROM foo")?
|
||||
/// .parse_statements()?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(dialect: &'a dyn Dialect) -> Self {
|
||||
Self {
|
||||
// Starts with an empty token stream; tokens are supplied afterwards
// through the `with_tokens*` / `try_with_sql` builder methods.
tokens: vec![],
|
||||
index: 0,
|
||||
dialect,
|
||||
// Default depth budget; override with `with_recursion_limit`.
recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_without_locations(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
|
||||
/// Specify the maximum recursion limit while parsing.
|
||||
///
|
||||
///
|
||||
/// [`Parser`] prevents stack overflows by returning
|
||||
/// [`ParserError::RecursionLimitExceeded`] if the parser exceeds
|
||||
/// this depth while processing the query.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
|
||||
/// # fn main() -> Result<(), ParserError> {
|
||||
/// let dialect = GenericDialect{};
|
||||
/// let result = Parser::new(&dialect)
|
||||
/// .with_recursion_limit(1)
|
||||
/// .try_with_sql("SELECT * FROM foo WHERE (a OR (b OR (c OR d)))")?
|
||||
/// .parse_statements();
|
||||
/// assert_eq!(result, Err(ParserError::RecursionLimitExceeded));
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
|
||||
// Installs a brand-new counter, replacing the previous one entirely.
self.recursion_counter = RecursionCounter::new(recursion_limit);
|
||||
self
|
||||
}
|
||||
|
||||
/// Reset this parser to parse the specified token stream
|
||||
pub fn with_tokens_with_locations(mut self, tokens: Vec<TokenWithLocation>) -> Self {
|
||||
self.tokens = tokens;
|
||||
// Rewind the cursor so parsing starts at the first of the new tokens.
// The dialect and recursion counter are left as they were.
self.index = 0;
|
||||
self
|
||||
}
|
||||
|
||||
/// Reset this parser state to parse the specified tokens
|
||||
pub fn with_tokens(self, tokens: Vec<Token>) -> Self {
|
||||
// Put in dummy locations
|
||||
let tokens_with_locations: Vec<TokenWithLocation> = tokens
|
||||
.into_iter()
|
||||
|
@ -128,49 +272,84 @@ impl<'a> Parser<'a> {
|
|||
location: Location { line: 0, column: 0 },
|
||||
})
|
||||
.collect();
|
||||
Parser::new_with_locations(tokens_with_locations, dialect)
|
||||
self.with_tokens_with_locations(tokens_with_locations)
|
||||
}
|
||||
|
||||
/// Parse the specified tokens
|
||||
pub fn new_with_locations(tokens: Vec<TokenWithLocation>, dialect: &'a dyn Dialect) -> Self {
|
||||
Parser {
|
||||
tokens,
|
||||
index: 0,
|
||||
dialect,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
|
||||
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
|
||||
let mut tokenizer = Tokenizer::new(dialect, sql);
|
||||
/// Tokenizes the SQL string and sets this [`Parser`]'s state to
|
||||
/// parse the resulting tokens
|
||||
///
|
||||
/// Returns an error if there was an error tokenizing the SQL string.
|
||||
///
|
||||
/// See [`Parser::new()`] for an example
|
||||
pub fn try_with_sql(self, sql: &str) -> Result<Self, ParserError> {
|
||||
debug!("Parsing sql '{}'...", sql);
|
||||
let mut tokenizer = Tokenizer::new(self.dialect, sql);
|
||||
let tokens = tokenizer.tokenize()?;
|
||||
let mut parser = Parser::new(tokens, dialect);
|
||||
Ok(self.with_tokens(tokens))
|
||||
}
|
||||
|
||||
/// Parse potentially multiple statements
|
||||
///
|
||||
/// Example
|
||||
/// ```
|
||||
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
|
||||
/// # fn main() -> Result<(), ParserError> {
|
||||
/// let dialect = GenericDialect{};
|
||||
/// let statements = Parser::new(&dialect)
|
||||
/// // Parse a SQL string with 2 separate statements
|
||||
/// .try_with_sql("SELECT * FROM foo; SELECT * FROM bar;")?
|
||||
/// .parse_statements()?;
|
||||
/// assert_eq!(statements.len(), 2);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn parse_statements(&mut self) -> Result<Vec<Statement>, ParserError> {
|
||||
let mut stmts = Vec::new();
|
||||
let mut expecting_statement_delimiter = false;
|
||||
debug!("Parsing sql '{}'...", sql);
|
||||
loop {
|
||||
// ignore empty statements (between successive statement delimiters)
|
||||
while parser.consume_token(&Token::SemiColon) {
|
||||
while self.consume_token(&Token::SemiColon) {
|
||||
expecting_statement_delimiter = false;
|
||||
}
|
||||
|
||||
if parser.peek_token() == Token::EOF {
|
||||
if self.peek_token() == Token::EOF {
|
||||
break;
|
||||
}
|
||||
if expecting_statement_delimiter {
|
||||
return parser.expected("end of statement", parser.peek_token());
|
||||
return self.expected("end of statement", self.peek_token());
|
||||
}
|
||||
|
||||
let statement = parser.parse_statement()?;
|
||||
let statement = self.parse_statement()?;
|
||||
stmts.push(statement);
|
||||
expecting_statement_delimiter = true;
|
||||
}
|
||||
Ok(stmts)
|
||||
}
|
||||
|
||||
/// Convenience method to parse a string with one or more SQL
|
||||
/// statements and produce an Abstract Syntax Tree (AST).
|
||||
///
|
||||
/// Example
|
||||
/// ```
|
||||
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
|
||||
/// # fn main() -> Result<(), ParserError> {
|
||||
/// let dialect = GenericDialect{};
|
||||
/// let statements = Parser::parse_sql(
|
||||
/// &dialect, "SELECT * FROM foo"
|
||||
/// )?;
|
||||
/// assert_eq!(statements.len(), 1);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
|
||||
// Builds a parser with the default recursion limit, tokenizes `sql`,
// then parses every statement; a tokenizer error is propagated by `?`.
Parser::new(dialect).try_with_sql(sql)?.parse_statements()
|
||||
}
|
||||
|
||||
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
|
||||
/// stopping before the statement separator, if any.
|
||||
pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
|
||||
let _guard = self.recursion_counter.try_decrease()?;
|
||||
|
||||
// allow the dialect to override statement parsing
|
||||
if let Some(statement) = self.dialect.parse_statement(self) {
|
||||
return statement;
|
||||
|
@ -364,6 +543,7 @@ impl<'a> Parser<'a> {
|
|||
|
||||
/// Parse a new expression
|
||||
pub fn parse_expr(&mut self) -> Result<Expr, ParserError> {
|
||||
// Consume one level of the recursion budget for the duration of this
// call; returns `ParserError::RecursionLimitExceeded` via `?` if none is
// left. The guard is held until this function returns.
let _guard = self.recursion_counter.try_decrease()?;
|
||||
// NOTE(review): the argument 0 is presumably the starting (lowest)
// precedence -- confirm against `parse_subexpr`.
self.parse_subexpr(0)
|
||||
}
|
||||
|
||||
|
@ -4512,6 +4692,7 @@ impl<'a> Parser<'a> {
|
|||
/// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
|
||||
/// expect the initial keyword to be already consumed
|
||||
pub fn parse_query(&mut self) -> Result<Query, ParserError> {
|
||||
let _guard = self.recursion_counter.try_decrease()?;
|
||||
let with = if self.parse_keyword(Keyword::WITH) {
|
||||
Some(With {
|
||||
recursive: self.parse_keyword(Keyword::RECURSIVE),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue