add {pre,post}_visit_query to Visitor (#1044)

This commit is contained in:
Joey Hain 2023-11-27 09:11:27 -08:00 committed by GitHub
parent 640b9394cd
commit 86aa044032
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 140 additions and 15 deletions

View file

@ -48,33 +48,86 @@ impl Visit for Bar {
}
```
Additionally certain types may wish to call a corresponding method on visitor before recursing
Some types may wish to call a corresponding method on the visitor:
```rust
#[derive(Visit, VisitMut)]
#[visit(with = "visit_expr")]
enum Expr {
A(),
B(String, #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] ObjectName, bool),
IsNull(Box<Expr>),
..
}
```
Will generate
This will result in the following sequence of visitor calls when an `IsNull`
expression is visited
```
visitor.pre_visit_expr(<is null expr>)
visitor.pre_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null expr>)
```
For some types it is only appropriate to call a particular visitor method in
some contexts. For example, not every `ObjectName` refers to a relation.
In these cases, the `visit` attribute can be used on the field for which we'd
like to call the method:
```rust
impl Visit for Bar {
#[derive(Visit, VisitMut)]
#[visit(with = "visit_table_factor")]
pub enum TableFactor {
Table {
#[visit(with = "visit_relation")]
name: ObjectName,
alias: Option<TableAlias>,
},
..
}
```
This will generate
```rust
impl Visit for TableFactor {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
visitor.visit_expr(self)?;
visitor.pre_visit_table_factor(self)?;
match self {
Self::A() => {}
Self::B(_1, _2, _3) => {
_1.visit(visitor)?;
visitor.visit_relation(_3)?;
_2.visit(visitor)?;
_3.visit(visitor)?;
Self::Table { name, alias } => {
visitor.pre_visit_relation(name)?;
alias.visit(name)?;
visitor.post_visit_relation(name)?;
alias.visit(visitor)?;
}
}
visitor.post_visit_table_factor(self)?;
ControlFlow::Continue(())
}
}
```
Note that annotating both the type and the field is incorrect as it will result
in redundant calls to the method. For example
```rust
#[derive(Visit, VisitMut)]
#[visit(with = "visit_expr")]
enum Expr {
IsNull(#[visit(with = "visit_expr")] Box<Expr>),
..
}
```
will result in these calls to the visitor
```
visitor.pre_visit_expr(<is null expr>)
visitor.pre_visit_expr(<is null operand>)
visitor.pre_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null expr>)
```

View file

@ -26,6 +26,7 @@ use crate::ast::*;
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
pub struct Query {
/// WITH (common table expressions, or CTEs)
pub with: Option<With>,
@ -739,7 +740,6 @@ pub enum TableFactor {
/// For example `FROM monthly_sales PIVOT(sum(amount) FOR MONTH IN ('JAN', 'FEB'))`
/// See <https://docs.snowflake.com/en/sql-reference/constructs/pivot>
Pivot {
#[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))]
table: Box<TableFactor>,
aggregate_function: Expr, // Function expression
value_column: Vec<Ident>,
@ -755,7 +755,6 @@ pub enum TableFactor {
///
/// See <https://docs.snowflake.com/en/sql-reference/constructs/unpivot>.
Unpivot {
#[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))]
table: Box<TableFactor>,
value: Ident,
name: Ident,

View file

@ -12,7 +12,7 @@
//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
use crate::ast::{Expr, ObjectName, Statement, TableFactor};
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
use core::ops::ControlFlow;
/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
@ -179,6 +179,16 @@ pub trait Visitor {
/// Type returned when the recursion returns early.
type Break;
/// Invoked for any queries that appear in the AST before visiting children
fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any queries that appear in the AST after visiting children
fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
@ -267,6 +277,16 @@ pub trait VisitorMut {
/// Type returned when the recursion returns early.
type Break;
/// Invoked for any queries that appear in the AST before visiting children
fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any queries that appear in the AST after visiting children
fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
@ -626,6 +646,18 @@ mod tests {
impl Visitor for TestVisitor {
type Break = ();
/// Invoked for any queries that appear in the AST before visiting children
fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
self.visited.push(format!("PRE: QUERY: {query}"));
ControlFlow::Continue(())
}
/// Invoked for any queries that appear in the AST after visiting children
fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
self.visited.push(format!("POST: QUERY: {query}"));
ControlFlow::Continue(())
}
fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
self.visited.push(format!("PRE: RELATION: {relation}"));
ControlFlow::Continue(())
@ -695,10 +727,12 @@ mod tests {
"SELECT * from table_name as my_table",
vec![
"PRE: STATEMENT: SELECT * FROM table_name AS my_table",
"PRE: QUERY: SELECT * FROM table_name AS my_table",
"PRE: TABLE FACTOR: table_name AS my_table",
"PRE: RELATION: table_name",
"POST: RELATION: table_name",
"POST: TABLE FACTOR: table_name AS my_table",
"POST: QUERY: SELECT * FROM table_name AS my_table",
"POST: STATEMENT: SELECT * FROM table_name AS my_table",
],
),
@ -706,6 +740,7 @@ mod tests {
"SELECT * from t1 join t2 on t1.id = t2.t1_id",
vec![
"PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
"PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
@ -720,6 +755,7 @@ mod tests {
"PRE: EXPR: t2.t1_id",
"POST: EXPR: t2.t1_id",
"POST: EXPR: t1.id = t2.t1_id",
"POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
"POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
],
),
@ -727,18 +763,22 @@ mod tests {
"SELECT * from t1 where EXISTS(SELECT column from t2)",
vec![
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
"POST: TABLE FACTOR: t1",
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT column FROM t2",
"PRE: EXPR: column",
"POST: EXPR: column",
"PRE: TABLE FACTOR: t2",
"PRE: RELATION: t2",
"POST: RELATION: t2",
"POST: TABLE FACTOR: t2",
"POST: QUERY: SELECT column FROM t2",
"POST: EXPR: EXISTS (SELECT column FROM t2)",
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
],
),
@ -746,18 +786,22 @@ mod tests {
"SELECT * from t1 where EXISTS(SELECT column from t2)",
vec![
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
"POST: TABLE FACTOR: t1",
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT column FROM t2",
"PRE: EXPR: column",
"POST: EXPR: column",
"PRE: TABLE FACTOR: t2",
"PRE: RELATION: t2",
"POST: RELATION: t2",
"POST: TABLE FACTOR: t2",
"POST: QUERY: SELECT column FROM t2",
"POST: EXPR: EXISTS (SELECT column FROM t2)",
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
],
),
@ -765,25 +809,54 @@ mod tests {
"SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
vec![
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
"POST: TABLE FACTOR: t1",
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT column FROM t2",
"PRE: EXPR: column",
"POST: EXPR: column",
"PRE: TABLE FACTOR: t2",
"PRE: RELATION: t2",
"POST: RELATION: t2",
"POST: TABLE FACTOR: t2",
"POST: QUERY: SELECT column FROM t2",
"POST: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: TABLE FACTOR: t3",
"PRE: RELATION: t3",
"POST: RELATION: t3",
"POST: TABLE FACTOR: t3",
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
],
),
(
concat!(
"SELECT * FROM monthly_sales ",
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
"ORDER BY EMPID"
),
vec![
"PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
"PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
"PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
"PRE: TABLE FACTOR: monthly_sales",
"PRE: RELATION: monthly_sales",
"POST: RELATION: monthly_sales",
"POST: TABLE FACTOR: monthly_sales",
"PRE: EXPR: SUM(a.amount)",
"PRE: EXPR: a.amount",
"POST: EXPR: a.amount",
"POST: EXPR: SUM(a.amount)",
"POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
"PRE: EXPR: EMPID",
"POST: EXPR: EMPID",
"POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
"POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
]
)
];
for (sql, expected) in tests {
let actual = do_visit(sql);