mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-07-07 17:04:59 +00:00
Add #[recursive]
(#1522)
Co-authored-by: Ifeanyi Ubah <ify1992@yahoo.com>
This commit is contained in:
parent
c973df35d6
commit
84e82e6e2e
8 changed files with 93 additions and 2 deletions
|
@ -37,8 +37,9 @@ name = "sqlparser"
|
||||||
path = "src/lib.rs"
|
path = "src/lib.rs"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["std"]
|
default = ["std", "recursive-protection"]
|
||||||
std = []
|
std = []
|
||||||
|
recursive-protection = ["std", "recursive"]
|
||||||
# Enable JSON output in the `cli` example:
|
# Enable JSON output in the `cli` example:
|
||||||
json_example = ["serde_json", "serde"]
|
json_example = ["serde_json", "serde"]
|
||||||
visitor = ["sqlparser_derive"]
|
visitor = ["sqlparser_derive"]
|
||||||
|
@ -46,6 +47,8 @@ visitor = ["sqlparser_derive"]
|
||||||
[dependencies]
|
[dependencies]
|
||||||
bigdecimal = { version = "0.4.1", features = ["serde"], optional = true }
|
bigdecimal = { version = "0.4.1", features = ["serde"], optional = true }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
|
recursive = { version = "0.1.1", optional = true}
|
||||||
|
|
||||||
serde = { version = "1.0", features = ["derive"], optional = true }
|
serde = { version = "1.0", features = ["derive"], optional = true }
|
||||||
# serde_json is only used in examples/cli, but we have to put it outside
|
# serde_json is only used in examples/cli, but we have to put it outside
|
||||||
# of dev-dependencies because of
|
# of dev-dependencies because of
|
||||||
|
|
|
@ -63,7 +63,7 @@ The following optional [crate features](https://doc.rust-lang.org/cargo/referen
|
||||||
|
|
||||||
* `serde`: Adds [Serde](https://serde.rs/) support by implementing `Serialize` and `Deserialize` for all AST nodes.
|
* `serde`: Adds [Serde](https://serde.rs/) support by implementing `Serialize` and `Deserialize` for all AST nodes.
|
||||||
* `visitor`: Adds a `Visitor` capable of recursively walking the AST tree.
|
* `visitor`: Adds a `Visitor` capable of recursively walking the AST tree.
|
||||||
|
* `recursive-protection` (enabled by default): Uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection.
|
||||||
|
|
||||||
## Syntax vs Semantics
|
## Syntax vs Semantics
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,10 @@ fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_
|
||||||
|
|
||||||
let expanded = quote! {
|
let expanded = quote! {
|
||||||
// The generated impl.
|
// The generated impl.
|
||||||
|
// Note that it uses [`recursive::recursive`] to protect from stack overflow.
|
||||||
|
// See tests in https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info.
|
||||||
impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
|
impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
|
||||||
|
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
|
||||||
fn visit<V: sqlparser::ast::#visitor_trait>(
|
fn visit<V: sqlparser::ast::#visitor_trait>(
|
||||||
&#modifier self,
|
&#modifier self,
|
||||||
visitor: &mut V
|
visitor: &mut V
|
||||||
|
|
|
@ -42,6 +42,46 @@ fn basic_queries(c: &mut Criterion) {
|
||||||
group.bench_function("sqlparser::with_select", |b| {
|
group.bench_function("sqlparser::with_select", |b| {
|
||||||
b.iter(|| Parser::parse_sql(&dialect, with_query).unwrap());
|
b.iter(|| Parser::parse_sql(&dialect, with_query).unwrap());
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let large_statement = {
|
||||||
|
let expressions = (0..1000)
|
||||||
|
.map(|n| format!("FN_{}(COL_{})", n, n))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ");
|
||||||
|
let tables = (0..1000)
|
||||||
|
.map(|n| format!("TABLE_{}", n))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" JOIN ");
|
||||||
|
let where_condition = (0..1000)
|
||||||
|
.map(|n| format!("COL_{} = {}", n, n))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" OR ");
|
||||||
|
let order_condition = (0..1000)
|
||||||
|
.map(|n| format!("COL_{} DESC", n))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ");
|
||||||
|
|
||||||
|
format!(
|
||||||
|
"SELECT {} FROM {} WHERE {} ORDER BY {}",
|
||||||
|
expressions, tables, where_condition, order_condition
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
group.bench_function("parse_large_statement", |b| {
|
||||||
|
b.iter(|| Parser::parse_sql(&dialect, criterion::black_box(large_statement.as_str())));
|
||||||
|
});
|
||||||
|
|
||||||
|
let large_statement = Parser::parse_sql(&dialect, large_statement.as_str())
|
||||||
|
.unwrap()
|
||||||
|
.pop()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
group.bench_function("format_large_statement", |b| {
|
||||||
|
b.iter(|| {
|
||||||
|
let formatted_query = large_statement.to_string();
|
||||||
|
assert_eq!(formatted_query, large_statement);
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
criterion_group!(benches, basic_queries);
|
criterion_group!(benches, basic_queries);
|
||||||
|
|
|
@ -1291,6 +1291,7 @@ impl fmt::Display for CastFormat {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for Expr {
|
impl fmt::Display for Expr {
|
||||||
|
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
Expr::Identifier(s) => write!(f, "{s}"),
|
Expr::Identifier(s) => write!(f, "{s}"),
|
||||||
|
|
|
@ -894,4 +894,29 @@ mod tests {
|
||||||
assert_eq!(actual, expected)
|
assert_eq!(actual, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes
|
||||||
|
|
||||||
|
impl Visitor for QuickVisitor {
|
||||||
|
type Break = ();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn overflow() {
|
||||||
|
let cond = (0..1000)
|
||||||
|
.map(|n| format!("X = {}", n))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" OR ");
|
||||||
|
let sql = format!("SELECT x where {0}", cond);
|
||||||
|
|
||||||
|
let dialect = GenericDialect {};
|
||||||
|
let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
|
||||||
|
let s = Parser::new(&dialect)
|
||||||
|
.with_tokens(tokens)
|
||||||
|
.parse_statement()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut visitor = QuickVisitor {};
|
||||||
|
s.visit(&mut visitor);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,6 +73,9 @@ mod recursion {
|
||||||
/// Note: Uses an [`std::rc::Rc`] and [`std::cell::Cell`] in order to satisfy the Rust
|
/// Note: Uses an [`std::rc::Rc`] and [`std::cell::Cell`] in order to satisfy the Rust
|
||||||
/// borrow checker so the automatic [`DepthGuard`] decrement a
|
/// borrow checker so the automatic [`DepthGuard`] decrement a
|
||||||
/// reference to the counter.
|
/// reference to the counter.
|
||||||
|
///
|
||||||
|
/// Note: when "recursive-protection" feature is enabled, this crate uses additional stack overflow protection
|
||||||
|
/// for some of its recursive methods. See [`recursive::recursive`] for more information.
|
||||||
pub(crate) struct RecursionCounter {
|
pub(crate) struct RecursionCounter {
|
||||||
remaining_depth: Rc<Cell<usize>>,
|
remaining_depth: Rc<Cell<usize>>,
|
||||||
}
|
}
|
||||||
|
@ -326,6 +329,9 @@ impl<'a> Parser<'a> {
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
|
///
|
||||||
|
/// Note: when "recursive-protection" feature is enabled, this crate uses additional stack overflow protection
|
||||||
|
/// for some of its recursive methods. See [`recursive::recursive`] for more information.
|
||||||
pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
|
pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
|
||||||
self.recursion_counter = RecursionCounter::new(recursion_limit);
|
self.recursion_counter = RecursionCounter::new(recursion_limit);
|
||||||
self
|
self
|
||||||
|
|
|
@ -12433,3 +12433,16 @@ fn test_table_sample() {
|
||||||
dialects.verified_stmt("SELECT * FROM tbl AS t TABLESAMPLE SYSTEM (50)");
|
dialects.verified_stmt("SELECT * FROM tbl AS t TABLESAMPLE SYSTEM (50)");
|
||||||
dialects.verified_stmt("SELECT * FROM tbl AS t TABLESAMPLE SYSTEM (50) REPEATABLE (10)");
|
dialects.verified_stmt("SELECT * FROM tbl AS t TABLESAMPLE SYSTEM (50) REPEATABLE (10)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn overflow() {
|
||||||
|
let expr = std::iter::repeat("1")
|
||||||
|
.take(1000)
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" + ");
|
||||||
|
let sql = format!("SELECT {}", expr);
|
||||||
|
|
||||||
|
let mut statements = Parser::parse_sql(&GenericDialect {}, sql.as_str()).unwrap();
|
||||||
|
let statement = statements.pop().unwrap();
|
||||||
|
assert_eq!(statement.to_string(), sql);
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue