Add configurable recursion limit to parser, to protect against stack overflows (#764)

This commit is contained in:
Andrew Lamb 2022-12-28 08:29:51 -05:00 committed by GitHub
parent 2c20ec0be5
commit 79d0baad73
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 302 additions and 31 deletions

View file

@ -12,11 +12,15 @@
//! SQL Parser for Rust //! SQL Parser for Rust
//! //!
//! Example code:
//!
//! This crate provides an ANSI:SQL 2011 lexer and parser that can parse SQL //! This crate provides an ANSI:SQL 2011 lexer and parser that can parse SQL
//! into an Abstract Syntax Tree (AST). //! into an Abstract Syntax Tree (AST).
//! //!
//! See [`Parser::parse_sql`](crate::parser::Parser::parse_sql) and
//! [`Parser::new`](crate::parser::Parser::new) for the Parsing API
//! and the [`ast`](crate::ast) crate for the AST structure.
//!
//! Example:
//!
//! ``` //! ```
//! use sqlparser::dialect::GenericDialect; //! use sqlparser::dialect::GenericDialect;
//! use sqlparser::parser::Parser; //! use sqlparser::parser::Parser;

View file

@ -37,6 +37,7 @@ use crate::tokenizer::*;
pub enum ParserError { pub enum ParserError {
TokenizerError(String), TokenizerError(String),
ParserError(String), ParserError(String),
RecursionLimitExceeded,
} }
// Use `Parser::expected` instead, if possible // Use `Parser::expected` instead, if possible
@ -55,6 +56,92 @@ macro_rules! return_ok_if_some {
}}; }};
} }
#[cfg(feature = "std")]
/// Implementation of [`RecursionCounter`] used when `std` is available
mod recursion {
    use core::cell::Cell;
    use std::rc::Rc;

    use super::ParserError;

    /// Tracks remaining recursion depth. This value is decremented on
    /// each successful call to `try_decrease()`; once it is 0, further
    /// calls return an error instead.
    ///
    /// Note: Uses an `Rc<Cell<usize>>` so that each [`DepthGuard`] can
    /// hold its own handle to the counter and restore the depth on drop
    /// without keeping a borrow of the counter's owner alive.
    pub(crate) struct RecursionCounter {
        remaining_depth: Rc<Cell<usize>>,
    }

    impl RecursionCounter {
        /// Creates a [`RecursionCounter`] with the specified maximum
        /// depth
        pub fn new(remaining_depth: usize) -> Self {
            Self {
                remaining_depth: Rc::new(Cell::new(remaining_depth)),
            }
        }

        /// Decreases the remaining depth by 1.
        ///
        /// Returns `Err(ParserError::RecursionLimitExceeded)` if the
        /// remaining depth is already 0. In that case the counter is
        /// left untouched: it never wraps below 0, so the limit stays
        /// in force even if the caller recovers from the error and
        /// keeps going.
        ///
        /// Otherwise returns a [`DepthGuard`] which adds 1 back to the
        /// remaining depth upon drop.
        pub fn try_decrease(&self) -> Result<DepthGuard, ParserError> {
            // Check before decrementing: `checked_sub` yields `None` at 0
            // instead of wrapping around to `usize::MAX`.
            match self.remaining_depth.get().checked_sub(1) {
                // ran out of space
                None => Err(ParserError::RecursionLimitExceeded),
                Some(remaining) => {
                    self.remaining_depth.set(remaining);
                    Ok(DepthGuard::new(Rc::clone(&self.remaining_depth)))
                }
            }
        }
    }

    /// Guard that increases the remaining depth by 1 on drop
    pub struct DepthGuard {
        remaining_depth: Rc<Cell<usize>>,
    }

    impl DepthGuard {
        fn new(remaining_depth: Rc<Cell<usize>>) -> Self {
            Self { remaining_depth }
        }
    }

    impl Drop for DepthGuard {
        fn drop(&mut self) {
            // Every guard corresponds to exactly one successful decrement,
            // so this restores the depth that guard consumed.
            let depth = self.remaining_depth.get();
            self.remaining_depth.set(depth + 1);
        }
    }
}
#[cfg(not(feature = "std"))]
mod recursion {
/// Implemenation [`RecursionCounter`] if std is NOT available (and does not
/// guard against stack overflow).
///
/// Has the same API as the std RecursionCounter implementation
/// but does not actually limit stack depth.
pub(crate) struct RecursionCounter {}
impl RecursionCounter {
pub fn new(_remaining_depth: usize) -> Self {
Self {}
}
pub fn try_decrease(&self) -> Result<DepthGuard, super::ParserError> {
Ok(DepthGuard {})
}
}
pub struct DepthGuard {}
}
use recursion::RecursionCounter;
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
pub enum IsOptional { pub enum IsOptional {
Optional, Optional,
@ -96,6 +183,7 @@ impl fmt::Display for ParserError {
match self { match self {
ParserError::TokenizerError(s) => s, ParserError::TokenizerError(s) => s,
ParserError::ParserError(s) => s, ParserError::ParserError(s) => s,
ParserError::RecursionLimitExceeded => "recursion limit exceeded",
} }
) )
} }
@ -104,22 +192,78 @@ impl fmt::Display for ParserError {
#[cfg(feature = "std")] #[cfg(feature = "std")]
impl std::error::Error for ParserError {} impl std::error::Error for ParserError {}
// By default, allow expressions up to this deep before erroring
const DEFAULT_REMAINING_DEPTH: usize = 50;
pub struct Parser<'a> { pub struct Parser<'a> {
tokens: Vec<TokenWithLocation>, tokens: Vec<TokenWithLocation>,
/// The index of the first unprocessed token in `self.tokens` /// The index of the first unprocessed token in `self.tokens`
index: usize, index: usize,
/// The current dialect to use
dialect: &'a dyn Dialect, dialect: &'a dyn Dialect,
/// ensure the stack does not overflow by limiting recursion depth
recursion_counter: RecursionCounter,
} }
impl<'a> Parser<'a> { impl<'a> Parser<'a> {
/// Parse the specified tokens /// Create a parser for a [`Dialect`]
/// To avoid breaking backwards compatibility, this function accepts ///
/// bare tokens. /// See also [`Parser::parse_sql`]
pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self { ///
Parser::new_without_locations(tokens, dialect) /// Example:
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let statements = Parser::new(&dialect)
/// .try_with_sql("SELECT * FROM foo")?
/// .parse_statements()?;
/// # Ok(())
/// # }
/// ```
pub fn new(dialect: &'a dyn Dialect) -> Self {
Self {
tokens: vec![],
index: 0,
dialect,
recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
}
} }
pub fn new_without_locations(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self { /// Specify the maximum recursion limit while parsing.
///
///
/// [`Parser`] prevents stack overflows by returning
/// [`ParserError::RecursionLimitExceeded`] if the parser exceeds
/// this depth while processing the query.
///
/// Example:
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let result = Parser::new(&dialect)
/// .with_recursion_limit(1)
/// .try_with_sql("SELECT * FROM foo WHERE (a OR (b OR (c OR d)))")?
/// .parse_statements();
/// assert_eq!(result, Err(ParserError::RecursionLimitExceeded));
/// # Ok(())
/// # }
/// ```
pub fn with_recursion_limit(self, recursion_limit: usize) -> Self {
    // Builder style: take the parser by value, swap in a fresh counter
    // initialised to the requested limit, and hand the parser back.
    let mut parser = self;
    parser.recursion_counter = RecursionCounter::new(recursion_limit);
    parser
}
/// Replaces this parser's token stream with `tokens` and rewinds the
/// read position (`index`) to 0, so parsing restarts at the first token.
pub fn with_tokens_with_locations(self, tokens: Vec<TokenWithLocation>) -> Self {
    let mut parser = self;
    parser.index = 0;
    parser.tokens = tokens;
    parser
}
/// Reset this parser state to parse the specified tokens
pub fn with_tokens(self, tokens: Vec<Token>) -> Self {
// Put in dummy locations // Put in dummy locations
let tokens_with_locations: Vec<TokenWithLocation> = tokens let tokens_with_locations: Vec<TokenWithLocation> = tokens
.into_iter() .into_iter()
@ -128,49 +272,84 @@ impl<'a> Parser<'a> {
location: Location { line: 0, column: 0 }, location: Location { line: 0, column: 0 },
}) })
.collect(); .collect();
Parser::new_with_locations(tokens_with_locations, dialect) self.with_tokens_with_locations(tokens_with_locations)
} }
/// Parse the specified tokens /// Tokenize the sql string and sets this [`Parser`]'s state to
pub fn new_with_locations(tokens: Vec<TokenWithLocation>, dialect: &'a dyn Dialect) -> Self { /// parse the resulting tokens
Parser { ///
tokens, /// Returns an error if there was an error tokenizing the SQL string.
index: 0, ///
dialect, /// See example on [`Parser::new()`] for an example
} pub fn try_with_sql(self, sql: &str) -> Result<Self, ParserError> {
} debug!("Parsing sql '{}'...", sql);
let mut tokenizer = Tokenizer::new(self.dialect, sql);
/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
let mut tokenizer = Tokenizer::new(dialect, sql);
let tokens = tokenizer.tokenize()?; let tokens = tokenizer.tokenize()?;
let mut parser = Parser::new(tokens, dialect); Ok(self.with_tokens(tokens))
}
/// Parse potentially multiple statements
///
/// Example
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let statements = Parser::new(&dialect)
/// // Parse a SQL string with 2 separate statements
/// .try_with_sql("SELECT * FROM foo; SELECT * FROM bar;")?
/// .parse_statements()?;
/// assert_eq!(statements.len(), 2);
/// # Ok(())
/// # }
/// ```
pub fn parse_statements(&mut self) -> Result<Vec<Statement>, ParserError> {
let mut stmts = Vec::new(); let mut stmts = Vec::new();
let mut expecting_statement_delimiter = false; let mut expecting_statement_delimiter = false;
debug!("Parsing sql '{}'...", sql);
loop { loop {
// ignore empty statements (between successive statement delimiters) // ignore empty statements (between successive statement delimiters)
while parser.consume_token(&Token::SemiColon) { while self.consume_token(&Token::SemiColon) {
expecting_statement_delimiter = false; expecting_statement_delimiter = false;
} }
if parser.peek_token() == Token::EOF { if self.peek_token() == Token::EOF {
break; break;
} }
if expecting_statement_delimiter { if expecting_statement_delimiter {
return parser.expected("end of statement", parser.peek_token()); return self.expected("end of statement", self.peek_token());
} }
let statement = parser.parse_statement()?; let statement = self.parse_statement()?;
stmts.push(statement); stmts.push(statement);
expecting_statement_delimiter = true; expecting_statement_delimiter = true;
} }
Ok(stmts) Ok(stmts)
} }
/// Convenience method that tokenizes and parses a string containing one
/// or more SQL statements, producing an Abstract Syntax Tree (AST) for
/// each statement.
///
/// Equivalent to creating a parser with [`Parser::new`], loading `sql`
/// with [`Parser::try_with_sql`], and then calling
/// [`Parser::parse_statements`]; an error from either step is returned.
///
/// Example
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let statements = Parser::parse_sql(
///   &dialect, "SELECT * FROM foo"
/// )?;
/// assert_eq!(statements.len(), 1);
/// # Ok(())
/// # }
/// ```
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
    let mut parser = Parser::new(dialect).try_with_sql(sql)?;
    parser.parse_statements()
}
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.), /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
/// stopping before the statement separator, if any. /// stopping before the statement separator, if any.
pub fn parse_statement(&mut self) -> Result<Statement, ParserError> { pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
let _guard = self.recursion_counter.try_decrease()?;
// allow the dialect to override statement parsing // allow the dialect to override statement parsing
if let Some(statement) = self.dialect.parse_statement(self) { if let Some(statement) = self.dialect.parse_statement(self) {
return statement; return statement;
@ -364,6 +543,7 @@ impl<'a> Parser<'a> {
/// Parse a new expression /// Parse a new expression
pub fn parse_expr(&mut self) -> Result<Expr, ParserError> { pub fn parse_expr(&mut self) -> Result<Expr, ParserError> {
let _guard = self.recursion_counter.try_decrease()?;
self.parse_subexpr(0) self.parse_subexpr(0)
} }
@ -4512,6 +4692,7 @@ impl<'a> Parser<'a> {
/// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
/// expect the initial keyword to be already consumed /// expect the initial keyword to be already consumed
pub fn parse_query(&mut self) -> Result<Query, ParserError> { pub fn parse_query(&mut self) -> Result<Query, ParserError> {
let _guard = self.recursion_counter.try_decrease()?;
let with = if self.parse_keyword(Keyword::WITH) { let with = if self.parse_keyword(Keyword::WITH) {
Some(With { Some(With {
recursive: self.parse_keyword(Keyword::RECURSIVE), recursive: self.parse_keyword(Keyword::RECURSIVE),

View file

@ -29,7 +29,6 @@ use core::fmt::Debug;
use crate::ast::*; use crate::ast::*;
use crate::dialect::*; use crate::dialect::*;
use crate::parser::{Parser, ParserError}; use crate::parser::{Parser, ParserError};
use crate::tokenizer::Tokenizer;
/// Tests use the methods on this struct to invoke the parser on one or /// Tests use the methods on this struct to invoke the parser on one or
/// multiple dialects. /// multiple dialects.
@ -65,9 +64,8 @@ impl TestedDialects {
F: Fn(&mut Parser) -> T, F: Fn(&mut Parser) -> T,
{ {
self.one_of_identical_results(|dialect| { self.one_of_identical_results(|dialect| {
let mut tokenizer = Tokenizer::new(dialect, sql); let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap();
let tokens = tokenizer.tokenize().unwrap(); f(&mut parser)
f(&mut Parser::new(tokens, dialect))
}) })
} }

View file

@ -6508,3 +6508,91 @@ fn parse_uncache_table() {
res.unwrap_err() res.unwrap_err()
); );
} }
#[test]
fn parse_deeply_nested_parens_hits_recursion_limits() {
    // 1000 unmatched opening parentheses must be rejected with the
    // recursion-limit error.
    let open_parens = "(".repeat(1000);
    let err = parse_sql_statements(&open_parens).unwrap_err();
    assert_eq!(ParserError::RecursionLimitExceeded, err);
}
#[test]
fn parse_deeply_nested_expr_hits_recursion_limits() {
    // A WHERE clause built from 100 terms must trip the default limit.
    let dialect = GenericDialect {};
    let predicate = make_where_clause(100);
    let sql = format!("SELECT id, user_id FROM test WHERE {}", predicate);

    let mut parser = Parser::new(&dialect)
        .try_with_sql(&sql)
        .expect("tokenize to work");
    let res = parser.parse_statements();

    assert_eq!(res, Err(ParserError::RecursionLimitExceeded));
}
#[test]
fn parse_deeply_nested_subquery_expr_hits_recursion_limits() {
    // Same 100-term predicate, this time inside an IN (...) subquery.
    let dialect = GenericDialect {};
    let predicate = make_where_clause(100);
    let sql = format!(
        "SELECT id, user_id where id IN (select id from t WHERE {})",
        predicate
    );

    let mut parser = Parser::new(&dialect)
        .try_with_sql(&sql)
        .expect("tokenize to work");
    let res = parser.parse_statements();

    assert_eq!(res, Err(ParserError::RecursionLimitExceeded));
}
#[test]
fn parse_with_recursion_limit() {
    let dialect = GenericDialect {};
    let predicate = make_where_clause(20);
    let sql = format!("SELECT id, user_id FROM test WHERE {}", predicate);

    // Tokenizes `sql` with a fresh parser, optionally overrides the
    // recursion limit, then parses all statements.
    let parse_with_limit = |limit: Option<usize>| {
        let parser = Parser::new(&dialect)
            .try_with_sql(&sql)
            .expect("tokenize to work");
        let mut parser = match limit {
            Some(limit) => parser.with_recursion_limit(limit),
            None => parser,
        };
        parser.parse_statements()
    };

    // Expect the statement to parse with default limit
    let res = parse_with_limit(None);
    assert!(res.is_ok(), "{:?}", res);

    // limit recursion to something smaller, expect parsing to fail
    let res = parse_with_limit(Some(20));
    assert_eq!(res, Err(ParserError::RecursionLimitExceeded));

    // limit recursion to 50, expect it to succeed
    let res = parse_with_limit(Some(50));
    assert!(res.is_ok(), "{:?}", res);
}
/// Makes a left-nested predicate of `num` comparisons joined by `OR`,
/// e.g. for `num == 3`:
///
/// `((user_id = 0) OR user_id = 1) OR user_id = 2`
///
/// Every term except the last is followed by a `)`, matched by the
/// `num - 1` opening parentheses at the front. `num == 0` yields an empty
/// string.
fn make_where_clause(num: usize) -> String {
    use std::fmt::Write;

    // Index of the final term; `saturating_sub` keeps `num == 0` from
    // underflowing (plain `num - 1` would panic there).
    let last = num.saturating_sub(1);
    let mut output = "(".repeat(last);
    for i in 0..num {
        if i > 0 {
            output.push_str(" OR ");
        }
        // Writing into a `String` cannot fail.
        write!(&mut output, "user_id = {}", i).unwrap();
        if i < last {
            output.push(')');
        }
    }
    output
}