Extend Visitor trait for Value type (#1725)

This commit is contained in:
tomershaniii 2025-02-22 07:48:39 +02:00 committed by GitHub
parent 3ace97c0ef
commit 8fc8082e9a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 98 additions and 8 deletions

View file

@ -33,7 +33,12 @@ use sqlparser_derive::{Visit, VisitMut};
/// Primitive SQL values such as number and string
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
#[cfg_attr(
feature = "visitor",
derive(Visit, VisitMut),
visit(with = "visit_value")
)]
pub enum Value {
/// Numeric literal
#[cfg(not(feature = "bigdecimal"))]

View file

@ -17,7 +17,7 @@
//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor, Value};
use core::ops::ControlFlow;
/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
@ -233,6 +233,16 @@ pub trait Visitor {
fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any Value that appear in the AST before visiting children
fn pre_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any Value that appear in the AST after visiting children
fn post_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
}
/// A visitor that can be used to mutate an AST tree.
@ -337,6 +347,16 @@ pub trait VisitorMut {
fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any value that appear in the AST before visiting children
fn pre_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any statements that appear in the AST after visiting children
fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
}
struct RelationVisitor<F>(F);
@ -647,6 +667,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::ast::Statement;
use crate::dialect::GenericDialect;
use crate::parser::Parser;
use crate::tokenizer::Tokenizer;
@ -720,7 +741,7 @@ mod tests {
}
}
fn do_visit(sql: &str) -> Vec<String> {
fn do_visit<V: Visitor>(sql: &str, visitor: &mut V) -> Statement {
let dialect = GenericDialect {};
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
let s = Parser::new(&dialect)
@ -728,9 +749,8 @@ mod tests {
.parse_statement()
.unwrap();
let mut visitor = TestVisitor::default();
s.visit(&mut visitor);
visitor.visited
s.visit(visitor);
s
}
#[test]
@ -889,8 +909,9 @@ mod tests {
),
];
for (sql, expected) in tests {
let actual = do_visit(sql);
let actual: Vec<_> = actual.iter().map(|x| x.as_str()).collect();
let mut visitor = TestVisitor::default();
let _ = do_visit(sql, &mut visitor);
let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect();
assert_eq!(actual, expected)
}
}
@ -920,3 +941,67 @@ mod tests {
s.visit(&mut visitor);
}
}
#[cfg(test)]
mod visit_mut_tests {
use crate::ast::{Statement, Value, VisitMut, VisitorMut};
use crate::dialect::GenericDialect;
use crate::parser::Parser;
use crate::tokenizer::Tokenizer;
use core::ops::ControlFlow;
#[derive(Default)]
struct MutatorVisitor {
index: u64,
}
impl VisitorMut for MutatorVisitor {
type Break = ();
fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
self.index += 1;
*value = Value::SingleQuotedString(format!("REDACTED_{}", self.index));
ControlFlow::Continue(())
}
fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
}
fn do_visit_mut<V: VisitorMut>(sql: &str, visitor: &mut V) -> Statement {
let dialect = GenericDialect {};
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
let mut s = Parser::new(&dialect)
.with_tokens(tokens)
.parse_statement()
.unwrap();
s.visit(visitor);
s
}
#[test]
fn test_value_redact() {
let tests = vec![
(
concat!(
"SELECT * FROM monthly_sales ",
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
"ORDER BY EMPID"
),
concat!(
"SELECT * FROM monthly_sales ",
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ",
"ORDER BY EMPID"
),
),
];
for (sql, expected) in tests {
let mut visitor = MutatorVisitor::default();
let mutated = do_visit_mut(sql, &mut visitor);
assert_eq!(mutated.to_string(), expected)
}
}
}