switch requires syntax to an attribute

This commit is contained in:
Aleksey Kladov 2019-05-21 18:49:18 +03:00
parent c816df7208
commit 6ea5413ef5
2 changed files with 17 additions and 40 deletions

View file

@ -5,23 +5,30 @@ use heck::CamelCase;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::ToTokens;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{
parse_macro_input, parse_quote, Attribute, FnArg, Ident, ItemTrait, Lit, MetaNameValue, Path,
parse_macro_input, parse_quote, Attribute, FnArg, Ident, ItemTrait, Path,
ReturnType, Token, TraitBound, TraitBoundModifier, TraitItem, Type, TypeParamBound,
};
/// Implementation for `[salsa::query_group]` decorator.
pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
let GroupDef {
group_struct,
requires,
} = parse_macro_input!(args as GroupDef);
let group_struct = parse_macro_input!(args as Ident);
let input: ItemTrait = parse_macro_input!(input as ItemTrait);
// println!("args: {:#?}", args);
// println!("input: {:#?}", input);
let (trait_attrs, salsa_attrs) = filter_attrs(input.attrs);
let mut requires: Punctuated<Path, Token![+]> = Punctuated::new();
for SalsaAttr { name, tts } in salsa_attrs {
match name.as_str() {
"requires" => {
requires.push(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
}
_ => panic!("unknown salsa attribute `{}`", name),
}
}
let trait_vis = input.vis;
let trait_name = input.ident;
let _generics = input.generics.clone();
@ -296,10 +303,9 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
// Emit the trait itself.
let mut output = {
let attrs = &input.attrs;
let bounds = &input.supertraits;
quote! {
#(#attrs)*
#(#trait_attrs)*
#trait_vis trait #trait_name : #bounds {
#query_fn_declarations
}
@ -550,37 +556,6 @@ fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
(other, salsa)
}
#[derive(Debug)]
struct GroupDef {
group_struct: Ident,
requires: Punctuated<Path, Token![+]>,
}
impl Parse for GroupDef {
fn parse(input: ParseStream) -> syn::Result<GroupDef> {
let res = GroupDef {
group_struct: input.parse()?,
requires: {
if input.lookahead1().peek(Token![,]) {
input.parse::<Token![,]>()?;
let name_value: MetaNameValue = input.parse()?;
if name_value.ident != "requires" {
return Err(syn::Error::new_spanned(name_value, "invalid attribute"));
}
let str_lit = match name_value.lit {
Lit::Str(it) => it,
_ => return Err(syn::Error::new_spanned(name_value, "invalid attribute")),
};
str_lit.parse_with(Punctuated::<Path, Token![+]>::parse_separated_nonempty)?
} else {
Punctuated::new()
}
},
};
Ok(res)
}
}
#[derive(Debug)]
struct Query {
fn_name: Ident,

View file

@ -26,7 +26,9 @@ mod queries {
db.input(x)
}
#[salsa::query_group(PubGroupStorage, requires = "PrivGroupA + PrivGroupB")]
#[salsa::query_group(PubGroupStorage)]
#[salsa::requires(PrivGroupA)]
#[salsa::requires(PrivGroupB)]
pub trait PubGroup: InputGroup {
fn public(&self, x: u32) -> u32;
}