Add derive based AST visitor (#765)

* Add derive based AST visitor

* Fix BigDecimal

* Fix no visitor feature

* Add test

* Rename visit_table to visit_relation

* Review feedback

* Add pre and post visit

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
This commit is contained in:
Raphael Taylor-Davies 2022-12-28 15:07:12 +00:00 committed by GitHub
parent 3e990466f8
commit dec3c2b818
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 771 additions and 11 deletions

23
derive/Cargo.toml Normal file
View file

@ -0,0 +1,23 @@
[package]
name = "sqlparser_derive"
description = "proc macro for sqlparser"
version = "0.1.0"
authors = ["Andy Grove <andygrove73@gmail.com>"]
homepage = "https://github.com/sqlparser-rs/sqlparser-rs"
documentation = "https://docs.rs/sqlparser/"
keywords = ["ansi", "sql", "lexer", "parser"]
repository = "https://github.com/sqlparser-rs/sqlparser-rs"
license = "Apache-2.0"
include = [
"src/**/*.rs",
"Cargo.toml",
]
edition = "2021"
[lib]
proc-macro = true
[dependencies]
syn = "1.0"
proc-macro2 = "1.0"
quote = "1.0"

79
derive/README.md Normal file
View file

@ -0,0 +1,79 @@
# SQL Parser Derive Macro
## Visit
This crate contains a procedural macro that can automatically derive implementations of the `Visit` trait.
```rust
#[derive(Visit)]
struct Foo {
boolean: bool,
bar: Bar,
}
#[derive(Visit)]
enum Bar {
A(),
B(String, bool),
C { named: i32 },
}
```
Will generate code akin to
```rust
impl Visit for Foo {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
self.boolean.visit(visitor)?;
self.bar.visit(visitor)?;
ControlFlow::Continue(())
}
}
impl Visit for Bar {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
match self {
Self::A() => {}
Self::B(_1, _2) => {
_1.visit(visitor)?;
_2.visit(visitor)?;
}
Self::C { named } => {
named.visit(visitor)?;
}
}
ControlFlow::Continue(())
}
}
```
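Additionally, certain types may wish to call corresponding methods on the visitor before and after recursing: `with = "visit_expr"` calls `pre_visit_expr` before, and `post_visit_expr` after, visiting the children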
```rust
#[derive(Visit)]
#[visit(with = "visit_expr")]
enum Expr {
A(),
B(String, #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] ObjectName, bool),
}
```
Will generate
```rust
impl Visit for Expr {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
visitor.pre_visit_expr(self)?;
match self {
Self::A() => {}
Self::B(_1, _2, _3) => {
_1.visit(visitor)?;
visitor.pre_visit_relation(_2)?;
_2.visit(visitor)?;
visitor.post_visit_relation(_2)?;
_3.visit(visitor)?;
}
}
visitor.post_visit_expr(self)?;
ControlFlow::Continue(())
}
}
```

184
derive/src/lib.rs Normal file
View file

@ -0,0 +1,184 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{
parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
Ident, Index, Lit, Meta, MetaNameValue, NestedMeta,
};
/// Implementation of `[#derive(Visit)]`
#[proc_macro_derive(Visit, attributes(visit))]
pub fn derive_visit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
// Parse the input tokens into a syntax tree.
let input = parse_macro_input!(input as DeriveInput);
let name = input.ident;
let attributes = Attributes::parse(&input.attrs);
// Add a bound `T: HeapSize` to every type parameter T.
let generics = add_trait_bounds(input.generics);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let (pre_visit, post_visit) = attributes.visit(quote!(self));
let children = visit_children(&input.data);
let expanded = quote! {
// The generated impl.
impl #impl_generics sqlparser::ast::Visit for #name #ty_generics #where_clause {
fn visit<V: sqlparser::ast::Visitor>(&self, visitor: &mut V) -> ::std::ops::ControlFlow<V::Break> {
#pre_visit
#children
#post_visit
::std::ops::ControlFlow::Continue(())
}
}
};
proc_macro::TokenStream::from(expanded)
}
/// Options parsed from the `#[visit(...)]` attributes attached to a type or field
///
/// The only form accepted by `Attributes::parse` is a list of `key = "value"`
/// pairs, and `with` is the only recognised key:
///
/// `#[visit(with = "visit_expr")]`
#[derive(Default)]
struct Attributes {
/// Identifier built from the string given to `with`; `None` when no `with`
/// was supplied. It is prefixed with `pre_` / `post_` to name the visitor
/// methods called around the visit.
with: Option<Ident>,
}
impl Attributes {
    /// Collects the options from every `#[visit(...)]` attribute in `attrs`.
    ///
    /// Attributes with any other path are skipped. If `with` is given more
    /// than once, the last value wins.
    ///
    /// # Panics
    ///
    /// Panics if a `visit` attribute cannot be parsed as a meta item, is not
    /// of the list form `#[visit(...)]`, or contains an entry that is not a
    /// `key = "value"` pair.
    fn parse(attrs: &[Attribute]) -> Self {
        let mut parsed = Self::default();
        let visit_attrs = attrs.iter().filter(|attr| attr.path.is_ident("visit"));
        for attr in visit_attrs {
            let list = match attr.parse_meta().expect("visit attribute") {
                Meta::List(list) => list,
                _ => panic!("Expected #[visit(...)]"),
            };
            for nested in list.nested.iter() {
                if let NestedMeta::Meta(Meta::NameValue(kv)) = nested {
                    parsed.parse_name_value(kv);
                } else {
                    panic!("Expected #[visit(key = \"value\")]");
                }
            }
        }
        parsed
    }

    /// Applies a single `key = value` entry to `self`.
    ///
    /// # Panics
    ///
    /// Panics if the key is not `with`, or if the value of `with` is not a
    /// string literal.
    fn parse_name_value(&mut self, v: &MetaNameValue) {
        // Reject unknown keys up front so the happy path reads straight through.
        if !v.path.is_ident("with") {
            panic!("Unrecognised kv attribute {}", v.path.to_token_stream())
        }
        match &v.lit {
            // Keep the literal's span on the identifier built from it.
            Lit::Str(s) => self.with = Some(format_ident!("{}", s.value(), span = s.span())),
            _ => panic!("Expected a string value, got {}", v.lit.to_token_stream()),
        }
    }

    /// Builds the statements to emit before and after visiting a node.
    ///
    /// `s` is the expression passed as the argument of the visitor calls.
    /// With `with = "name"` the result is
    /// `(visitor.pre_name(s)?;, visitor.post_name(s)?;)`; without `with`,
    /// both halves are `None`.
    fn visit(&self, s: TokenStream) -> (Option<TokenStream>, Option<TokenStream>) {
        match &self.with {
            Some(method) => {
                let pre = format_ident!("pre_{}", method);
                let post = format_ident!("post_{}", method);
                (
                    Some(quote!(visitor.#pre(#s)?;)),
                    Some(quote!(visitor.#post(#s)?;)),
                )
            }
            None => (None, None),
        }
    }
}
/// Returns `generics` with a `sqlparser::ast::Visit` bound appended to every
/// type parameter. Lifetime and const parameters are left untouched.
fn add_trait_bounds(mut generics: Generics) -> Generics {
    let type_params = generics.params.iter_mut().filter_map(|param| match param {
        GenericParam::Type(type_param) => Some(type_param),
        _ => None,
    });
    for type_param in type_params {
        type_param.bounds.push(parse_quote!(sqlparser::ast::Visit));
    }
    generics
}
/// Generates the statements that visit every child of the annotated type.
///
/// * Structs: one `sqlparser::ast::Visit::visit` call per field, in
///   declaration order, addressed as `&self.<name>` or `&self.<index>`.
/// * Enums: a `match self` with one arm per variant; the arm binds the
///   variant's fields (tuple fields as `_0`, `_1`, ...) and visits them in
///   declaration order.
///
/// A field carrying `#[visit(with = "name")]` additionally gets
/// `visitor.pre_name(..)?;` emitted before its visit and
/// `visitor.post_name(..)?;` after it.
///
/// # Panics
///
/// Panics if `data` describes a union, or if a field's `#[visit(...)]`
/// attribute is rejected by `Attributes::parse`.
fn visit_children(data: &Data) -> TokenStream {
    match data {
        Data::Struct(data) => match &data.fields {
            Fields::Named(fields) => {
                let recurse = fields.named.iter().map(|f| {
                    let name = &f.ident;
                    let attributes = Attributes::parse(&f.attrs);
                    let (pre_visit, post_visit) = attributes.visit(quote!(&self.#name));
                    // Span the call at the field so type errors point at it.
                    quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#name, visitor)?; #post_visit)
                });
                quote! {
                    #(#recurse)*
                }
            }
            Fields::Unnamed(fields) => {
                let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
                    // `Index` renders as a bare tuple index (`0`, not `0usize`).
                    let index = Index::from(i);
                    let attributes = Attributes::parse(&f.attrs);
                    let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
                    quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#index, visitor)?; #post_visit)
                });
                quote! {
                    #(#recurse)*
                }
            }
            Fields::Unit => {
                // Nothing to visit.
                quote!()
            }
        },
        Data::Enum(data) => {
            let statements = data.variants.iter().map(|v| {
                let name = &v.ident;
                match &v.fields {
                    Fields::Named(fields) => {
                        // Bind each field under its own name in the pattern.
                        let names = fields.named.iter().map(|f| &f.ident);
                        let visit = fields.named.iter().map(|f| {
                            let name = &f.ident;
                            let attributes = Attributes::parse(&f.attrs);
                            let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
                            quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
                        });
                        quote!(
                            Self::#name { #(#names),* } => {
                                #(#visit)*
                            }
                        )
                    }
                    Fields::Unnamed(fields) => {
                        // Tuple fields are bound positionally as `_0`, `_1`, ...
                        let names = fields.unnamed.iter().enumerate().map(|(i, f)| format_ident!("_{}", i, span = f.span()));
                        let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
                            let name = format_ident!("_{}", i);
                            let attributes = Attributes::parse(&f.attrs);
                            let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
                            quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
                        });
                        quote! {
                            Self::#name ( #(#names),*) => {
                                #(#visit)*
                            }
                        }
                    }
                    Fields::Unit => {
                        quote! {
                            Self::#name => {}
                        }
                    }
                }
            });
            quote! {
                match self {
                    #(#statements),*
                }
            }
        }
        // Say what is unsupported instead of a bare "not implemented" panic.
        Data::Union(_) => panic!("#[derive(Visit)] is not supported for unions"),
    }
}