mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-07-07 17:04:59 +00:00
add {pre,post}_visit_query to Visitor (#1044)
This commit is contained in:
parent
640b9394cd
commit
86aa044032
3 changed files with 140 additions and 15 deletions
|
@ -48,33 +48,86 @@ impl Visit for Bar {
|
|||
}
|
||||
```
|
||||
|
||||
Additionally certain types may wish to call a corresponding method on visitor before recursing
|
||||
Some types may wish to call a corresponding method on the visitor:
|
||||
|
||||
```rust
|
||||
#[derive(Visit, VisitMut)]
|
||||
#[visit(with = "visit_expr")]
|
||||
enum Expr {
|
||||
A(),
|
||||
B(String, #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] ObjectName, bool),
|
||||
IsNull(Box<Expr>),
|
||||
..
|
||||
}
|
||||
```
|
||||
|
||||
Will generate
|
||||
This will result in the following sequence of visitor calls when an `IsNull`
|
||||
expression is visited
|
||||
|
||||
```
|
||||
visitor.pre_visit_expr(<is null expr>)
|
||||
visitor.pre_visit_expr(<is null operand>)
|
||||
visitor.post_visit_expr(<is null operand>)
|
||||
visitor.post_visit_expr(<is null expr>)
|
||||
```
|
||||
|
||||
For some types it is only appropriate to call a particular visitor method in
|
||||
some contexts. For example, not every `ObjectName` refers to a relation.
|
||||
|
||||
In these cases, the `visit` attribute can be used on the field for which we'd
|
||||
like to call the method:
|
||||
|
||||
```rust
|
||||
impl Visit for Bar {
|
||||
#[derive(Visit, VisitMut)]
|
||||
#[visit(with = "visit_table_factor")]
|
||||
pub enum TableFactor {
|
||||
Table {
|
||||
#[visit(with = "visit_relation")]
|
||||
name: ObjectName,
|
||||
alias: Option<TableAlias>,
|
||||
},
|
||||
..
|
||||
}
|
||||
```
|
||||
|
||||
This will generate
|
||||
|
||||
```rust
|
||||
impl Visit for TableFactor {
|
||||
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
|
||||
visitor.visit_expr(self)?;
|
||||
visitor.pre_visit_table_factor(self)?;
|
||||
match self {
|
||||
Self::A() => {}
|
||||
Self::B(_1, _2, _3) => {
|
||||
_1.visit(visitor)?;
|
||||
visitor.visit_relation(_3)?;
|
||||
_2.visit(visitor)?;
|
||||
_3.visit(visitor)?;
|
||||
Self::Table { name, alias } => {
|
||||
visitor.pre_visit_relation(name)?;
|
||||
name.visit(visitor)?;
|
||||
visitor.post_visit_relation(name)?;
|
||||
alias.visit(visitor)?;
|
||||
}
|
||||
}
|
||||
visitor.post_visit_table_factor(self)?;
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Note that annotating both the type and the field is incorrect as it will result
|
||||
in redundant calls to the method. For example
|
||||
|
||||
```rust
|
||||
#[derive(Visit, VisitMut)]
|
||||
#[visit(with = "visit_expr")]
|
||||
enum Expr {
|
||||
IsNull(#[visit(with = "visit_expr")] Box<Expr>),
|
||||
..
|
||||
}
|
||||
```
|
||||
|
||||
will result in these calls to the visitor
|
||||
|
||||
|
||||
```
|
||||
visitor.pre_visit_expr(<is null expr>)
|
||||
visitor.pre_visit_expr(<is null operand>)
|
||||
visitor.pre_visit_expr(<is null operand>)
|
||||
visitor.post_visit_expr(<is null operand>)
|
||||
visitor.post_visit_expr(<is null operand>)
|
||||
visitor.post_visit_expr(<is null expr>)
|
||||
```
|
||||
|
|
|
@ -26,6 +26,7 @@ use crate::ast::*;
|
|||
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
|
||||
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
|
||||
pub struct Query {
|
||||
/// WITH (common table expressions, or CTEs)
|
||||
pub with: Option<With>,
|
||||
|
@ -739,7 +740,6 @@ pub enum TableFactor {
|
|||
/// For example `FROM monthly_sales PIVOT(sum(amount) FOR MONTH IN ('JAN', 'FEB'))`
|
||||
/// See <https://docs.snowflake.com/en/sql-reference/constructs/pivot>
|
||||
Pivot {
|
||||
#[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))]
|
||||
table: Box<TableFactor>,
|
||||
aggregate_function: Expr, // Function expression
|
||||
value_column: Vec<Ident>,
|
||||
|
@ -755,7 +755,6 @@ pub enum TableFactor {
|
|||
///
|
||||
/// See <https://docs.snowflake.com/en/sql-reference/constructs/unpivot>.
|
||||
Unpivot {
|
||||
#[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))]
|
||||
table: Box<TableFactor>,
|
||||
value: Ident,
|
||||
name: Ident,
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
|
||||
|
||||
use crate::ast::{Expr, ObjectName, Statement, TableFactor};
|
||||
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
|
||||
use core::ops::ControlFlow;
|
||||
|
||||
/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
|
||||
|
@ -179,6 +179,16 @@ pub trait Visitor {
|
|||
/// Type returned when the recursion returns early.
|
||||
type Break;
|
||||
|
||||
/// Invoked for any queries that appear in the AST before visiting children
|
||||
fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
|
||||
/// Invoked for any queries that appear in the AST after visiting children
|
||||
fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
|
||||
/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
|
||||
fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
|
||||
ControlFlow::Continue(())
|
||||
|
@ -267,6 +277,16 @@ pub trait VisitorMut {
|
|||
/// Type returned when the recursion returns early.
|
||||
type Break;
|
||||
|
||||
/// Invoked for any queries that appear in the AST before visiting children
|
||||
fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
|
||||
/// Invoked for any queries that appear in the AST after visiting children
|
||||
fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
|
||||
/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
|
||||
fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
|
||||
ControlFlow::Continue(())
|
||||
|
@ -626,6 +646,18 @@ mod tests {
|
|||
impl Visitor for TestVisitor {
|
||||
type Break = ();
|
||||
|
||||
/// Invoked for any queries that appear in the AST before visiting children
|
||||
fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
|
||||
self.visited.push(format!("PRE: QUERY: {query}"));
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
|
||||
/// Invoked for any queries that appear in the AST after visiting children
|
||||
fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
|
||||
self.visited.push(format!("POST: QUERY: {query}"));
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
|
||||
fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
|
||||
self.visited.push(format!("PRE: RELATION: {relation}"));
|
||||
ControlFlow::Continue(())
|
||||
|
@ -695,10 +727,12 @@ mod tests {
|
|||
"SELECT * from table_name as my_table",
|
||||
vec![
|
||||
"PRE: STATEMENT: SELECT * FROM table_name AS my_table",
|
||||
"PRE: QUERY: SELECT * FROM table_name AS my_table",
|
||||
"PRE: TABLE FACTOR: table_name AS my_table",
|
||||
"PRE: RELATION: table_name",
|
||||
"POST: RELATION: table_name",
|
||||
"POST: TABLE FACTOR: table_name AS my_table",
|
||||
"POST: QUERY: SELECT * FROM table_name AS my_table",
|
||||
"POST: STATEMENT: SELECT * FROM table_name AS my_table",
|
||||
],
|
||||
),
|
||||
|
@ -706,6 +740,7 @@ mod tests {
|
|||
"SELECT * from t1 join t2 on t1.id = t2.t1_id",
|
||||
vec![
|
||||
"PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
|
||||
"PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
|
||||
"PRE: TABLE FACTOR: t1",
|
||||
"PRE: RELATION: t1",
|
||||
"POST: RELATION: t1",
|
||||
|
@ -720,6 +755,7 @@ mod tests {
|
|||
"PRE: EXPR: t2.t1_id",
|
||||
"POST: EXPR: t2.t1_id",
|
||||
"POST: EXPR: t1.id = t2.t1_id",
|
||||
"POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
|
||||
"POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
|
||||
],
|
||||
),
|
||||
|
@ -727,18 +763,22 @@ mod tests {
|
|||
"SELECT * from t1 where EXISTS(SELECT column from t2)",
|
||||
vec![
|
||||
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
|
||||
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
|
||||
"PRE: TABLE FACTOR: t1",
|
||||
"PRE: RELATION: t1",
|
||||
"POST: RELATION: t1",
|
||||
"POST: TABLE FACTOR: t1",
|
||||
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
|
||||
"PRE: QUERY: SELECT column FROM t2",
|
||||
"PRE: EXPR: column",
|
||||
"POST: EXPR: column",
|
||||
"PRE: TABLE FACTOR: t2",
|
||||
"PRE: RELATION: t2",
|
||||
"POST: RELATION: t2",
|
||||
"POST: TABLE FACTOR: t2",
|
||||
"POST: QUERY: SELECT column FROM t2",
|
||||
"POST: EXPR: EXISTS (SELECT column FROM t2)",
|
||||
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
|
||||
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
|
||||
],
|
||||
),
|
||||
|
@ -746,18 +786,22 @@ mod tests {
|
|||
"SELECT * from t1 where EXISTS(SELECT column from t2)",
|
||||
vec![
|
||||
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
|
||||
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
|
||||
"PRE: TABLE FACTOR: t1",
|
||||
"PRE: RELATION: t1",
|
||||
"POST: RELATION: t1",
|
||||
"POST: TABLE FACTOR: t1",
|
||||
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
|
||||
"PRE: QUERY: SELECT column FROM t2",
|
||||
"PRE: EXPR: column",
|
||||
"POST: EXPR: column",
|
||||
"PRE: TABLE FACTOR: t2",
|
||||
"PRE: RELATION: t2",
|
||||
"POST: RELATION: t2",
|
||||
"POST: TABLE FACTOR: t2",
|
||||
"POST: QUERY: SELECT column FROM t2",
|
||||
"POST: EXPR: EXISTS (SELECT column FROM t2)",
|
||||
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
|
||||
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
|
||||
],
|
||||
),
|
||||
|
@ -765,25 +809,54 @@ mod tests {
|
|||
"SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
|
||||
vec![
|
||||
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
|
||||
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
|
||||
"PRE: TABLE FACTOR: t1",
|
||||
"PRE: RELATION: t1",
|
||||
"POST: RELATION: t1",
|
||||
"POST: TABLE FACTOR: t1",
|
||||
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
|
||||
"PRE: QUERY: SELECT column FROM t2",
|
||||
"PRE: EXPR: column",
|
||||
"POST: EXPR: column",
|
||||
"PRE: TABLE FACTOR: t2",
|
||||
"PRE: RELATION: t2",
|
||||
"POST: RELATION: t2",
|
||||
"POST: TABLE FACTOR: t2",
|
||||
"POST: QUERY: SELECT column FROM t2",
|
||||
"POST: EXPR: EXISTS (SELECT column FROM t2)",
|
||||
"PRE: TABLE FACTOR: t3",
|
||||
"PRE: RELATION: t3",
|
||||
"POST: RELATION: t3",
|
||||
"POST: TABLE FACTOR: t3",
|
||||
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
|
||||
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
|
||||
],
|
||||
),
|
||||
(
|
||||
concat!(
|
||||
"SELECT * FROM monthly_sales ",
|
||||
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
|
||||
"ORDER BY EMPID"
|
||||
),
|
||||
vec![
|
||||
"PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
|
||||
"PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
|
||||
"PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
|
||||
"PRE: TABLE FACTOR: monthly_sales",
|
||||
"PRE: RELATION: monthly_sales",
|
||||
"POST: RELATION: monthly_sales",
|
||||
"POST: TABLE FACTOR: monthly_sales",
|
||||
"PRE: EXPR: SUM(a.amount)",
|
||||
"PRE: EXPR: a.amount",
|
||||
"POST: EXPR: a.amount",
|
||||
"POST: EXPR: SUM(a.amount)",
|
||||
"POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
|
||||
"PRE: EXPR: EMPID",
|
||||
"POST: EXPR: EMPID",
|
||||
"POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
|
||||
"POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
|
||||
]
|
||||
)
|
||||
];
|
||||
for (sql, expected) in tests {
|
||||
let actual = do_visit(sql);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue