Add a mutable visitor (#782)

* Add a mutable visitor

This adds the ability to mutate parsed sql queries.
Previously, only visitors taking an immutable reference to the visited structures were allowed.

* add utility functions for mutable visits

* bump version numbers
This commit is contained in:
Ophir LOJKINE 2023-01-02 12:48:20 +01:00 committed by GitHub
parent 86d71f2109
commit 524b8a7e7b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 428 additions and 150 deletions

View file

@ -33,7 +33,7 @@ serde = { version = "1.0", features = ["derive"], optional = true }
# of dev-dependencies because of
# https://github.com/rust-lang/cargo/issues/1596
serde_json = { version = "1.0", optional = true }
sqlparser_derive = { version = "0.1", path = "derive", optional = true }
sqlparser_derive = { version = "0.1.1", path = "derive", optional = true }
[dev-dependencies]
simple_logger = "4.0"

View file

@ -1,7 +1,7 @@
[package]
name = "sqlparser_derive"
description = "proc macro for sqlparser"
version = "0.1.0"
version = "0.1.1"
authors = ["sqlparser-rs authors"]
homepage = "https://github.com/sqlparser-rs/sqlparser-rs"
documentation = "https://docs.rs/sqlparser_derive/"

View file

@ -6,13 +6,13 @@ This crate contains a procedural macro that can automatically derive
implementations of the `Visit` trait in the [sqlparser](https://crates.io/crates/sqlparser) crate
```rust
#[derive(Visit)]
#[derive(Visit, VisitMut)]
struct Foo {
boolean: bool,
bar: Bar,
}
#[derive(Visit)]
#[derive(Visit, VisitMut)]
enum Bar {
A(),
B(String, bool),
@ -51,7 +51,7 @@ impl Visit for Bar {
Additionally certain types may wish to call a corresponding method on visitor before recursing
```rust
#[derive(Visit)]
#[derive(Visit, VisitMut)]
#[visit(with = "visit_expr")]
enum Expr {
A(),

View file

@ -6,25 +6,58 @@ use syn::{
Ident, Index, Lit, Meta, MetaNameValue, NestedMeta,
};
/// Implementation of `[#derive(Visit)]`
#[proc_macro_derive(VisitMut, attributes(visit))]
pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
derive_visit(input, &VisitType {
visit_trait: quote!(VisitMut),
visitor_trait: quote!(VisitorMut),
modifier: Some(quote!(mut)),
})
}
/// Implementation of `[#derive(Visit)]`
#[proc_macro_derive(Visit, attributes(visit))]
pub fn derive_visit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
derive_visit(input, &VisitType {
visit_trait: quote!(Visit),
visitor_trait: quote!(Visitor),
modifier: None,
})
}
struct VisitType {
visit_trait: TokenStream,
visitor_trait: TokenStream,
modifier: Option<TokenStream>,
}
fn derive_visit(
input: proc_macro::TokenStream,
visit_type: &VisitType,
) -> proc_macro::TokenStream {
// Parse the input tokens into a syntax tree.
let input = parse_macro_input!(input as DeriveInput);
let name = input.ident;
let VisitType { visit_trait, visitor_trait, modifier } = visit_type;
let attributes = Attributes::parse(&input.attrs);
// Add a bound `T: HeapSize` to every type parameter T.
let generics = add_trait_bounds(input.generics);
// Add a bound `T: Visit` to every type parameter T.
let generics = add_trait_bounds(input.generics, visit_type);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let (pre_visit, post_visit) = attributes.visit(quote!(self));
let children = visit_children(&input.data);
let children = visit_children(&input.data, visit_type);
let expanded = quote! {
// The generated impl.
impl #impl_generics sqlparser::ast::Visit for #name #ty_generics #where_clause {
fn visit<V: sqlparser::ast::Visitor>(&self, visitor: &mut V) -> ::std::ops::ControlFlow<V::Break> {
impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
fn visit<V: sqlparser::ast::#visitor_trait>(
&#modifier self,
visitor: &mut V
) -> ::std::ops::ControlFlow<V::Break> {
#pre_visit
#children
#post_visit
@ -92,25 +125,25 @@ impl Attributes {
}
// Add a bound `T: Visit` to every type parameter T.
fn add_trait_bounds(mut generics: Generics) -> Generics {
fn add_trait_bounds(mut generics: Generics, VisitType{visit_trait, ..}: &VisitType) -> Generics {
for param in &mut generics.params {
if let GenericParam::Type(ref mut type_param) = *param {
type_param.bounds.push(parse_quote!(sqlparser::ast::Visit));
type_param.bounds.push(parse_quote!(sqlparser::ast::#visit_trait));
}
}
generics
}
// Generate the body of the visit implementation for the given type
fn visit_children(data: &Data) -> TokenStream {
fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType) -> TokenStream {
match data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => {
let recurse = fields.named.iter().map(|f| {
let name = &f.ident;
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#name, visitor)?; #post_visit)
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit)
});
quote! {
#(#recurse)*
@ -121,7 +154,7 @@ fn visit_children(data: &Data) -> TokenStream {
let index = Index::from(i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#index, visitor)?; #post_visit)
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit)
});
quote! {
#(#recurse)*
@ -140,8 +173,8 @@ fn visit_children(data: &Data) -> TokenStream {
let visit = fields.named.iter().map(|f| {
let name = &f.ident;
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
});
quote!(
@ -155,8 +188,8 @@ fn visit_children(data: &Data) -> TokenStream {
let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
let name = format_ident!("_{}", i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
});
quote! {

View file

@ -18,7 +18,7 @@ use core::fmt;
use serde::{Deserialize, Serialize};
#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};
use crate::ast::ObjectName;
@ -27,7 +27,7 @@ use super::value::escape_single_quote_string;
/// SQL data types
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum DataType {
/// Fixed-length character type e.g. CHARACTER(10)
Character(Option<CharacterLength>),
@ -341,7 +341,7 @@ fn format_datetime_precision_and_tz(
/// guarantee compatibility with the input query we must maintain its exact information.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TimezoneInfo {
/// No information about time zone. E.g., TIMESTAMP
None,
@ -389,7 +389,7 @@ impl fmt::Display for TimezoneInfo {
/// [standard]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#exact-numeric-type
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ExactNumberInfo {
/// No additional information e.g. `DECIMAL`
None,
@ -420,7 +420,7 @@ impl fmt::Display for ExactNumberInfo {
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#character-length
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CharacterLength {
/// Default (if VARYING) or maximum (if not VARYING) length
pub length: u64,
@ -443,7 +443,7 @@ impl fmt::Display for CharacterLength {
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#char-length-units
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CharLengthUnits {
/// CHARACTERS unit
Characters,

View file

@ -21,7 +21,7 @@ use core::fmt;
use serde::{Deserialize, Serialize};
#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};
use crate::ast::value::escape_single_quote_string;
use crate::ast::{display_comma_separated, display_separated, DataType, Expr, Ident, ObjectName};
@ -30,7 +30,7 @@ use crate::tokenizer::Token;
/// An `ALTER TABLE` (`Statement::AlterTable`) operation
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum AlterTableOperation {
/// `ADD <table_constraint>`
AddConstraint(TableConstraint),
@ -100,7 +100,7 @@ pub enum AlterTableOperation {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum AlterIndexOperation {
RenameIndex { index_name: ObjectName },
}
@ -224,7 +224,7 @@ impl fmt::Display for AlterIndexOperation {
/// An `ALTER COLUMN` (`Statement::AlterTable`) operation
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum AlterColumnOperation {
/// `SET NOT NULL`
SetNotNull,
@ -268,7 +268,7 @@ impl fmt::Display for AlterColumnOperation {
/// `ALTER TABLE ADD <constraint>` statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TableConstraint {
/// `[ CONSTRAINT <name> ] { PRIMARY KEY | UNIQUE } (<columns>)`
Unique {
@ -433,7 +433,7 @@ impl fmt::Display for TableConstraint {
/// [1]: https://dev.mysql.com/doc/refman/8.0/en/create-table.html
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum KeyOrIndexDisplay {
/// Nothing to display
None,
@ -469,7 +469,7 @@ impl fmt::Display for KeyOrIndexDisplay {
/// [3]: https://www.postgresql.org/docs/14/sql-createindex.html
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum IndexType {
BTree,
Hash,
@ -488,7 +488,7 @@ impl fmt::Display for IndexType {
/// SQL column definition
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ColumnDef {
pub name: Ident,
pub data_type: DataType,
@ -524,7 +524,7 @@ impl fmt::Display for ColumnDef {
/// "column options," and we allow any column option to be named.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ColumnOptionDef {
pub name: Option<Ident>,
pub option: ColumnOption,
@ -540,7 +540,7 @@ impl fmt::Display for ColumnOptionDef {
/// TABLE` statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ColumnOption {
/// `NULL`
Null,
@ -630,7 +630,7 @@ fn display_constraint_name(name: &'_ Option<Ident>) -> impl fmt::Display + '_ {
/// Used in foreign key constraints in `ON UPDATE` and `ON DELETE` options.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ReferentialAction {
Restrict,
Cascade,

View file

@ -5,7 +5,7 @@ use alloc::{boxed::Box, format, string::String, vec, vec::Vec};
use serde::{Deserialize, Serialize};
#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};
use crate::ast::{
ColumnDef, FileFormat, HiveDistributionStyle, HiveFormat, ObjectName, OnCommit, Query,
@ -43,7 +43,7 @@ use crate::parser::ParserError;
/// [1]: crate::ast::Statement::CreateTable
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CreateTableBuilder {
pub or_replace: bool,
pub temporary: bool,

View file

@ -23,7 +23,7 @@ use core::fmt;
use serde::{Deserialize, Serialize};
#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};
pub use self::data_type::{
CharLengthUnits, CharacterLength, DataType, ExactNumberInfo, TimezoneInfo,
@ -96,7 +96,7 @@ where
/// An identifier, decomposed into its value or character data and the quote style.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Ident {
/// The value of the identifier without quotes.
pub value: String,
@ -157,7 +157,7 @@ impl fmt::Display for Ident {
/// A name of a table, view, custom type, etc., possibly multi-part, i.e. db.schema.obj
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ObjectName(pub Vec<Ident>);
impl fmt::Display for ObjectName {
@ -170,7 +170,7 @@ impl fmt::Display for ObjectName {
/// `ARRAY[..]`, or `[..]`
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Array {
/// The list of expressions between brackets
pub elem: Vec<Expr>,
@ -193,7 +193,7 @@ impl fmt::Display for Array {
/// JsonOperator
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum JsonOperator {
/// -> keeps the value as json
Arrow,
@ -257,7 +257,11 @@ impl fmt::Display for JsonOperator {
/// inappropriate type, like `WHERE 1` or `SELECT 1=1`, as necessary.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit), visit(with = "visit_expr"))]
#[cfg_attr(
feature = "visitor",
derive(Visit, VisitMut),
visit(with = "visit_expr")
)]
pub enum Expr {
/// Identifier e.g. table name or column name
Identifier(Ident),
@ -911,7 +915,7 @@ impl fmt::Display for Expr {
/// A window specification (i.e. `OVER (PARTITION BY .. ORDER BY .. etc.)`)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct WindowSpec {
pub partition_by: Vec<Expr>,
pub order_by: Vec<OrderByExpr>,
@ -957,7 +961,7 @@ impl fmt::Display for WindowSpec {
/// reject invalid bounds like `ROWS UNBOUNDED FOLLOWING` before execution.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct WindowFrame {
pub units: WindowFrameUnits,
pub start_bound: WindowFrameBound,
@ -983,7 +987,7 @@ impl Default for WindowFrame {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum WindowFrameUnits {
Rows,
Range,
@ -1003,7 +1007,7 @@ impl fmt::Display for WindowFrameUnits {
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum WindowFrameBound {
/// `CURRENT ROW`
CurrentRow,
@ -1027,7 +1031,7 @@ impl fmt::Display for WindowFrameBound {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum AddDropSync {
ADD,
DROP,
@ -1046,7 +1050,7 @@ impl fmt::Display for AddDropSync {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ShowCreateObject {
Event,
Function,
@ -1071,7 +1075,7 @@ impl fmt::Display for ShowCreateObject {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CommentObject {
Column,
Table,
@ -1088,7 +1092,7 @@ impl fmt::Display for CommentObject {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum Password {
Password(Expr),
NullPassword,
@ -1098,7 +1102,11 @@ pub enum Password {
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit), visit(with = "visit_statement"))]
#[cfg_attr(
feature = "visitor",
derive(Visit, VisitMut),
visit(with = "visit_statement")
)]
pub enum Statement {
/// Analyze (Hive)
Analyze {
@ -2739,7 +2747,7 @@ impl fmt::Display for Statement {
/// ```
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum SequenceOptions {
IncrementBy(Expr, bool),
MinValue(MinMaxValue),
@ -2804,7 +2812,7 @@ impl fmt::Display for SequenceOptions {
/// [ MINVALUE minvalue | NO MINVALUE ] [ MAXVALUE maxvalue | NO MAXVALUE ]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum MinMaxValue {
// clause is not specified
Empty,
@ -2816,7 +2824,7 @@ pub enum MinMaxValue {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
#[non_exhaustive]
pub enum OnInsert {
/// ON DUPLICATE KEY UPDATE (MySQL when the key already exists, then execute an update instead)
@ -2827,21 +2835,21 @@ pub enum OnInsert {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct OnConflict {
pub conflict_target: Option<ConflictTarget>,
pub action: OnConflictAction,
}
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ConflictTarget {
Columns(Vec<Ident>),
OnConstraint(ObjectName),
}
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum OnConflictAction {
DoNothing,
DoUpdate(DoUpdate),
@ -2849,7 +2857,7 @@ pub enum OnConflictAction {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct DoUpdate {
/// Column assignments
pub assignments: Vec<Assignment>,
@ -2911,7 +2919,7 @@ impl fmt::Display for OnConflictAction {
/// Privileges granted in a GRANT statement or revoked in a REVOKE statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum Privileges {
/// All privileges applicable to the object type
All {
@ -2948,7 +2956,7 @@ impl fmt::Display for Privileges {
/// Specific direction for FETCH statement
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FetchDirection {
Count { limit: Value },
Next,
@ -3012,7 +3020,7 @@ impl fmt::Display for FetchDirection {
/// A privilege on a database object (table, sequence, etc.).
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum Action {
Connect,
Create,
@ -3062,7 +3070,7 @@ impl fmt::Display for Action {
/// Objects on which privileges are granted in a GRANT statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum GrantObjects {
/// Grant privileges on `ALL SEQUENCES IN SCHEMA <schema_name> [, ...]`
AllSequencesInSchema { schemas: Vec<ObjectName> },
@ -3109,7 +3117,7 @@ impl fmt::Display for GrantObjects {
/// SQL assignment `foo = expr` as used in SQLUpdate
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Assignment {
pub id: Vec<Ident>,
pub value: Expr,
@ -3123,7 +3131,7 @@ impl fmt::Display for Assignment {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FunctionArgExpr {
Expr(Expr),
/// Qualified wildcard, e.g. `alias.*` or `schema.table.*`.
@ -3144,7 +3152,7 @@ impl fmt::Display for FunctionArgExpr {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FunctionArg {
Named { name: Ident, arg: FunctionArgExpr },
Unnamed(FunctionArgExpr),
@ -3161,7 +3169,7 @@ impl fmt::Display for FunctionArg {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CloseCursor {
All,
Specific { name: Ident },
@ -3179,7 +3187,7 @@ impl fmt::Display for CloseCursor {
/// A function call
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Function {
pub name: ObjectName,
pub args: Vec<FunctionArg>,
@ -3193,7 +3201,7 @@ pub struct Function {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum AnalyzeFormat {
TEXT,
GRAPHVIZ,
@ -3235,7 +3243,7 @@ impl fmt::Display for Function {
/// External table's available file format
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FileFormat {
TEXTFILE,
SEQUENCEFILE,
@ -3265,7 +3273,7 @@ impl fmt::Display for FileFormat {
/// [ WITHIN GROUP (ORDER BY <within_group1>[, ...] ) ]`
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ListAgg {
pub distinct: bool,
pub expr: Box<Expr>,
@ -3303,7 +3311,7 @@ impl fmt::Display for ListAgg {
/// The `ON OVERFLOW` clause of a LISTAGG invocation
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ListAggOnOverflow {
/// `ON OVERFLOW ERROR`
Error,
@ -3341,7 +3349,7 @@ impl fmt::Display for ListAggOnOverflow {
/// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ArrayAgg {
pub distinct: bool,
pub expr: Box<Expr>,
@ -3378,7 +3386,7 @@ impl fmt::Display for ArrayAgg {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ObjectType {
Table,
View,
@ -3403,7 +3411,7 @@ impl fmt::Display for ObjectType {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum KillType {
Connection,
Query,
@ -3424,7 +3432,7 @@ impl fmt::Display for KillType {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum HiveDistributionStyle {
PARTITIONED {
columns: Vec<ColumnDef>,
@ -3444,7 +3452,7 @@ pub enum HiveDistributionStyle {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum HiveRowFormat {
SERDE { class: String },
DELIMITED,
@ -3452,7 +3460,7 @@ pub enum HiveRowFormat {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
#[allow(clippy::large_enum_variant)]
pub enum HiveIOFormat {
IOF {
@ -3466,7 +3474,7 @@ pub enum HiveIOFormat {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct HiveFormat {
pub row_format: Option<HiveRowFormat>,
pub storage: Option<HiveIOFormat>,
@ -3475,7 +3483,7 @@ pub struct HiveFormat {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct SqlOption {
pub name: Ident,
pub value: Value,
@ -3489,7 +3497,7 @@ impl fmt::Display for SqlOption {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TransactionMode {
AccessMode(TransactionAccessMode),
IsolationLevel(TransactionIsolationLevel),
@ -3507,7 +3515,7 @@ impl fmt::Display for TransactionMode {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TransactionAccessMode {
ReadOnly,
ReadWrite,
@ -3525,7 +3533,7 @@ impl fmt::Display for TransactionAccessMode {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TransactionIsolationLevel {
ReadUncommitted,
ReadCommitted,
@ -3547,7 +3555,7 @@ impl fmt::Display for TransactionIsolationLevel {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ShowStatementFilter {
Like(String),
ILike(String),
@ -3571,7 +3579,7 @@ impl fmt::Display for ShowStatementFilter {
/// for more details.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum SqliteOnConflict {
Rollback,
Abort,
@ -3595,7 +3603,7 @@ impl fmt::Display for SqliteOnConflict {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CopyTarget {
Stdin,
Stdout,
@ -3627,7 +3635,7 @@ impl fmt::Display for CopyTarget {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum OnCommit {
DeleteRows,
PreserveRows,
@ -3639,7 +3647,7 @@ pub enum OnCommit {
/// <https://www.postgresql.org/docs/14/sql-copy.html>
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CopyOption {
/// FORMAT format_name
Format(Ident),
@ -3693,7 +3701,7 @@ impl fmt::Display for CopyOption {
/// <https://www.postgresql.org/docs/8.4/sql-copy.html>
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CopyLegacyOption {
/// BINARY
Binary,
@ -3722,7 +3730,7 @@ impl fmt::Display for CopyLegacyOption {
/// <https://www.postgresql.org/docs/8.4/sql-copy.html>
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CopyLegacyCsvOption {
/// HEADER
Header,
@ -3754,7 +3762,7 @@ impl fmt::Display for CopyLegacyCsvOption {
///
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum MergeClause {
MatchedUpdate {
predicate: Option<Expr>,
@ -3816,7 +3824,7 @@ impl fmt::Display for MergeClause {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum DiscardObject {
ALL,
PLANS,
@ -3838,7 +3846,7 @@ impl fmt::Display for DiscardObject {
/// Optional context modifier for statements that can be or `LOCAL`, or `SESSION`.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ContextModifier {
/// No context defined. Each dialect defines the default in this scenario.
None,
@ -3884,7 +3892,7 @@ impl fmt::Display for DropFunctionOption {
/// Function describe in DROP FUNCTION.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct DropFunctionDesc {
pub name: ObjectName,
pub args: Option<Vec<OperateFunctionArg>>,
@ -3903,7 +3911,7 @@ impl fmt::Display for DropFunctionDesc {
/// Function argument in CREATE OR DROP FUNCTION.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct OperateFunctionArg {
pub mode: Option<ArgMode>,
pub name: Option<Ident>,
@ -3952,7 +3960,7 @@ impl fmt::Display for OperateFunctionArg {
/// The mode of an argument in CREATE FUNCTION.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ArgMode {
In,
Out,
@ -3972,7 +3980,7 @@ impl fmt::Display for ArgMode {
/// These attributes inform the query optimizer about the behavior of the function.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FunctionBehavior {
Immutable,
Stable,
@ -3991,7 +3999,7 @@ impl fmt::Display for FunctionBehavior {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FunctionDefinition {
SingleQuotedDef(String),
DoubleDollarDef(String),
@ -4013,7 +4021,7 @@ impl fmt::Display for FunctionDefinition {
/// for more details
#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CreateFunctionBody {
/// LANGUAGE lang_name
pub language: Option<Ident>,
@ -4052,7 +4060,7 @@ impl fmt::Display for CreateFunctionBody {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CreateFunctionUsing {
Jar(String),
File(String),
@ -4075,7 +4083,7 @@ impl fmt::Display for CreateFunctionUsing {
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#schema-definition
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum SchemaName {
/// Only schema name specified: `<schema name>`.
Simple(ObjectName),
@ -4106,7 +4114,7 @@ impl fmt::Display for SchemaName {
/// [1]: https://dev.mysql.com/doc/refman/8.0/en/fulltext-search.html#function_match
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum SearchModifier {
/// `IN NATURAL LANGUAGE MODE`.
InNaturalLanguageMode,

View file

@ -19,14 +19,14 @@ use alloc::{string::String, vec::Vec};
use serde::{Deserialize, Serialize};
#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};
use super::display_separated;
/// Unary operators
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum UnaryOperator {
Plus,
Minus,
@ -64,7 +64,7 @@ impl fmt::Display for UnaryOperator {
/// Binary operators
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum BinaryOperator {
Plus,
Minus,

View file

@ -17,7 +17,7 @@ use alloc::{boxed::Box, vec::Vec};
use serde::{Deserialize, Serialize};
#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};
use crate::ast::*;
@ -25,7 +25,7 @@ use crate::ast::*;
/// including `WITH`, `UNION` / other set operations, and `ORDER BY`.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Query {
/// WITH (common table expressions, or CTEs)
pub with: Option<With>,
@ -73,7 +73,7 @@ impl fmt::Display for Query {
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum SetExpr {
/// Restricted SELECT .. FROM .. HAVING (no ORDER BY or set operations)
Select(Box<Select>),
@ -122,7 +122,7 @@ impl fmt::Display for SetExpr {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum SetOperator {
Union,
Except,
@ -144,7 +144,7 @@ impl fmt::Display for SetOperator {
// For example, BigQuery does not support `DISTINCT` for `EXCEPT` and `INTERSECT`
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum SetQuantifier {
All,
Distinct,
@ -164,7 +164,7 @@ impl fmt::Display for SetQuantifier {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// A [`TABLE` command]( https://www.postgresql.org/docs/current/sql-select.html#SQL-TABLE)
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Table {
pub table_name: Option<String>,
pub schema_name: Option<String>,
@ -191,7 +191,7 @@ impl fmt::Display for Table {
/// to a set operation like `UNION`.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Select {
pub distinct: bool,
/// MSSQL syntax: `TOP (<N>) [ PERCENT ] [ WITH TIES ]`
@ -276,7 +276,7 @@ impl fmt::Display for Select {
/// A hive LATERAL VIEW with potential column aliases
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct LateralView {
/// LATERAL VIEW
pub lateral_view: Expr,
@ -310,7 +310,7 @@ impl fmt::Display for LateralView {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct With {
pub recursive: bool,
pub cte_tables: Vec<Cte>,
@ -333,7 +333,7 @@ impl fmt::Display for With {
/// number of columns in the query matches the number of columns in the query.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Cte {
pub alias: TableAlias,
pub query: Box<Query>,
@ -353,7 +353,7 @@ impl fmt::Display for Cte {
/// One item of the comma-separated list following `SELECT`
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum SelectItem {
/// Any expression, not followed by `[ AS ] alias`
UnnamedExpr(Expr),
@ -368,7 +368,7 @@ pub enum SelectItem {
/// Additional options for wildcards, e.g. Snowflake `EXCLUDE` and Bigquery `EXCEPT`.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct WildcardAdditionalOptions {
/// `[EXCLUDE...]`.
pub opt_exclude: Option<ExcludeSelectItem>,
@ -397,7 +397,7 @@ impl fmt::Display for WildcardAdditionalOptions {
/// ```
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ExcludeSelectItem {
/// Single column name without parenthesis.
///
@ -437,7 +437,7 @@ impl fmt::Display for ExcludeSelectItem {
/// ```
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ExceptSelectItem {
/// First guaranteed column.
pub fist_elemnt: Ident,
@ -483,7 +483,7 @@ impl fmt::Display for SelectItem {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct TableWithJoins {
pub relation: TableFactor,
pub joins: Vec<Join>,
@ -502,7 +502,7 @@ impl fmt::Display for TableWithJoins {
/// A table name or a parenthesized subquery with an optional alias
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TableFactor {
Table {
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
@ -633,7 +633,7 @@ impl fmt::Display for TableFactor {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct TableAlias {
pub name: Ident,
pub columns: Vec<Ident>,
@ -651,7 +651,7 @@ impl fmt::Display for TableAlias {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Join {
pub relation: TableFactor,
pub join_operator: JoinOperator,
@ -746,7 +746,7 @@ impl fmt::Display for Join {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum JoinOperator {
Inner(JoinConstraint),
LeftOuter(JoinConstraint),
@ -769,7 +769,7 @@ pub enum JoinOperator {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum JoinConstraint {
On(Expr),
Using(Vec<Ident>),
@ -780,7 +780,7 @@ pub enum JoinConstraint {
/// An `ORDER BY` expression
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct OrderByExpr {
pub expr: Expr,
/// Optional `ASC` or `DESC`
@ -808,7 +808,7 @@ impl fmt::Display for OrderByExpr {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Offset {
pub value: Expr,
pub rows: OffsetRows,
@ -823,7 +823,7 @@ impl fmt::Display for Offset {
/// Stores the keyword after `OFFSET <number>`
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum OffsetRows {
/// Omitting ROW/ROWS is non-standard MySQL quirk.
None,
@ -843,7 +843,7 @@ impl fmt::Display for OffsetRows {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Fetch {
pub with_ties: bool,
pub percent: bool,
@ -864,7 +864,7 @@ impl fmt::Display for Fetch {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct LockClause {
pub lock_type: LockType,
pub of: Option<ObjectName>,
@ -886,7 +886,7 @@ impl fmt::Display for LockClause {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum LockType {
Share,
Update,
@ -904,7 +904,7 @@ impl fmt::Display for LockType {
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum NonBlock {
Nowait,
SkipLocked,
@ -922,7 +922,7 @@ impl fmt::Display for NonBlock {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Top {
/// SQL semantic equivalent of LIMIT but with same structure as FETCH.
pub with_ties: bool,
@ -944,7 +944,7 @@ impl fmt::Display for Top {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Values {
/// Was there an explict ROWs keyword (MySQL)?
/// <https://dev.mysql.com/doc/refman/8.0/en/values.html>
@ -968,7 +968,7 @@ impl fmt::Display for Values {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct SelectInto {
pub temporary: bool,
pub unlogged: bool,

View file

@ -21,12 +21,12 @@ use bigdecimal::BigDecimal;
use serde::{Deserialize, Serialize};
#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};
/// Primitive SQL values such as number and string
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum Value {
/// Numeric literal
#[cfg(not(feature = "bigdecimal"))]
@ -77,7 +77,7 @@ impl fmt::Display for Value {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct DollarQuotedString {
pub value: String,
pub tag: Option<String>,
@ -98,7 +98,7 @@ impl fmt::Display for DollarQuotedString {
#[derive(Debug, Copy, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum DateTimeField {
Year,
Month,
@ -229,7 +229,7 @@ pub fn escape_escaped_string(s: &str) -> EscapeEscapedStringLiteral<'_> {
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TrimWhereField {
Both,
Leading,

View file

@ -24,12 +24,27 @@ use core::ops::ControlFlow;
/// using the [Visit](sqlparser_derive::Visit) proc macro.
///
/// ```text
/// #[cfg_attr(feature = "visitor", derive(Visit))]
/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
/// ```
pub trait Visit {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
}
/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
/// recursively visiting parsed SQL statements.
///
/// # Note
///
/// This trait should be automatically derived for sqlparser AST nodes
/// using the [VisitMut](sqlparser_derive::VisitMut) proc macro.
///
/// ```text
/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
/// ```
pub trait VisitMut {
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
}
impl<T: Visit> Visit for Option<T> {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
if let Some(s) = self {
@ -54,6 +69,30 @@ impl<T: Visit> Visit for Box<T> {
}
}
impl<T: VisitMut> VisitMut for Option<T> {
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
if let Some(s) = self {
s.visit(visitor)?;
}
ControlFlow::Continue(())
}
}
impl<T: VisitMut> VisitMut for Vec<T> {
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
for v in self {
v.visit(visitor)?;
}
ControlFlow::Continue(())
}
}
impl<T: VisitMut> VisitMut for Box<T> {
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
T::visit(self, visitor)
}
}
macro_rules! visit_noop {
($($t:ty),+) => {
$(impl Visit for $t {
@ -61,6 +100,11 @@ macro_rules! visit_noop {
ControlFlow::Continue(())
}
})+
$(impl VisitMut for $t {
fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
ControlFlow::Continue(())
}
})+
};
}
@ -166,6 +210,84 @@ pub trait Visitor {
}
}
/// A visitor that can be used to mutate an AST tree.
///
/// `previst_` methods are invoked before visiting all children of the
/// node and `postvisit_` methods are invoked after visiting all
/// children of the node.
///
/// # See also
///
/// These methods provide a more concise way of visiting nodes of a certain type:
/// * [visit_relations_mut]
/// * [visit_expressions_mut]
/// * [visit_statements_mut]
///
/// # Example
/// ```
/// # use sqlparser::parser::Parser;
/// # use sqlparser::dialect::GenericDialect;
/// # use sqlparser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
/// # use core::ops::ControlFlow;
///
/// // A visitor that replaces "to_replace" with "replaced" in all expressions
/// struct Replacer;
///
/// // Visit each expression after its children have been visited
/// impl VisitorMut for Replacer {
/// type Break = ();
///
/// fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
/// if let Expr::Identifier(Ident{ value, ..}) = expr {
/// *value = value.replace("to_replace", "replaced")
/// }
/// ControlFlow::Continue(())
/// }
/// }
///
/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
///
/// // Drive the visitor through the AST
/// statements.visit(&mut Replacer);
///
/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
/// ```
pub trait VisitorMut {
/// Type returned when the recursion returns early.
type Break;
/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any expressions that appear in the AST before visiting children
fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any expressions that appear in the AST
fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any statements that appear in the AST before visiting children
fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Invoked for any statements that appear in the AST after visiting children
fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
}
struct RelationVisitor<F>(F);
impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
@ -176,6 +298,14 @@ impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F>
}
}
impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
type Break = E;
fn pre_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
self.0(relation)
}
}
/// Invokes the provided closure on all relations (e.g. table names) present in `v`
///
/// # Example
@ -213,6 +343,36 @@ where
ControlFlow::Continue(())
}
/// Invokes the provided closure on all relations (e.g. table names) present in `v`
///
/// # Example
/// ```
/// # use sqlparser::parser::Parser;
/// # use sqlparser::dialect::GenericDialect;
/// # use sqlparser::ast::{ObjectName, visit_relations_mut};
/// # use core::ops::ControlFlow;
/// let sql = "SELECT a FROM foo";
/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
/// .unwrap();
///
/// // visit statements, renaming table foo to bar
/// visit_relations_mut(&mut statements, |table| {
/// table.0[0].value = table.0[0].value.replace("foo", "bar");
/// ControlFlow::<()>::Continue(())
/// });
///
/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
/// ```
pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
where
V: VisitMut,
F: FnMut(&mut ObjectName) -> ControlFlow<E>,
{
let mut visitor = RelationVisitor(f);
v.visit(&mut visitor)?;
ControlFlow::Continue(())
}
struct ExprVisitor<F>(F);
impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
@ -223,6 +383,14 @@ impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
}
}
impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
type Break = E;
fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
self.0(expr)
}
}
/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
///
/// # Example
@ -262,6 +430,36 @@ where
ControlFlow::Continue(())
}
/// Invokes the provided closure on all expressions present in `v`
///
/// # Example
/// ```
/// # use sqlparser::parser::Parser;
/// # use sqlparser::dialect::GenericDialect;
/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
/// # use core::ops::ControlFlow;
/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
///
/// // Remove all select limits in sub-queries
/// visit_expressions_mut(&mut statements, |expr| {
/// if let Expr::Subquery(q) = expr {
/// q.limit = None
/// }
/// ControlFlow::<()>::Continue(())
/// });
///
/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
/// ```
pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
where
V: VisitMut,
F: FnMut(&mut Expr) -> ControlFlow<E>,
{
v.visit(&mut ExprVisitor(f))?;
ControlFlow::Continue(())
}
struct StatementVisitor<F>(F);
impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
@ -272,6 +470,14 @@ impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F>
}
}
impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
type Break = E;
fn pre_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
self.0(statement)
}
}
/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
///
/// # Example
@ -309,6 +515,37 @@ where
ControlFlow::Continue(())
}
/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
///
/// # Example
/// ```
/// # use sqlparser::parser::Parser;
/// # use sqlparser::dialect::GenericDialect;
/// # use sqlparser::ast::{Statement, visit_statements_mut};
/// # use core::ops::ControlFlow;
/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
///
/// // Remove all select limits in outer statements (not in sub-queries)
/// visit_statements_mut(&mut statements, |stmt| {
/// if let Statement::Query(q) = stmt {
/// q.limit = None
/// }
/// ControlFlow::<()>::Continue(())
/// });
///
/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
/// ```
pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
where
V: VisitMut,
F: FnMut(&mut Statement) -> ControlFlow<E>,
{
v.visit(&mut StatementVisitor(f))?;
ControlFlow::Continue(())
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -26,7 +26,7 @@
use serde::{Deserialize, Serialize};
#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};
/// Defines a string constant for a single keyword: `kw_def!(SELECT);`
/// expands to `pub const SELECT = "SELECT";`
@ -47,7 +47,7 @@ macro_rules! define_keywords {
),*) => {
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
#[allow(non_camel_case_types)]
pub enum Keyword {
NoKeyword,

View file

@ -32,7 +32,7 @@ use core::str::Chars;
use serde::{Deserialize, Serialize};
#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};
use crate::ast::DollarQuotedString;
use crate::dialect::SnowflakeDialect;
@ -42,7 +42,7 @@ use crate::keywords::{Keyword, ALL_KEYWORDS, ALL_KEYWORDS_INDEX};
/// SQL Token enumeration
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum Token {
/// An end-of-file marker, not a real token
EOF,
@ -269,7 +269,7 @@ impl Token {
/// A keyword (like SELECT) or an optionally quoted SQL identifier
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Word {
/// The value of the token, without the enclosing quotes, and with the
/// escape sequences (if any) processed (TODO: escapes are not handled)
@ -308,7 +308,7 @@ impl Word {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum Whitespace {
Space,
Newline,