internal: vendor query-group-macro

This commit is contained in:
David Barsky 2025-03-06 16:00:08 -05:00
parent bd0289e0e9
commit 7a7ff470ca
17 changed files with 1989 additions and 22 deletions

View file

@ -0,0 +1,437 @@
//! A macro that mimics the old Salsa-style `#[query_group]` macro.
use core::fmt;
use std::vec;
use proc_macro::TokenStream;
use proc_macro2::Span;
use queries::{
GeneratedInputStruct, InputQuery, InputSetter, InputSetterWithDurability, Intern, Lookup,
Queries, SetterKind, TrackedQuery, Transparent,
};
use quote::{format_ident, quote, ToTokens};
use syn::spanned::Spanned;
use syn::visit_mut::VisitMut;
use syn::{parse_quote, Attribute, FnArg, ItemTrait, Path, TraitItem, TraitItemFn};
mod queries;
#[proc_macro_attribute]
pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
match query_group_impl(args, input.clone()) {
Ok(tokens) => tokens,
Err(e) => token_stream_with_error(input, e),
}
}
#[derive(Debug)]
struct InputStructField {
name: proc_macro2::TokenStream,
ty: proc_macro2::TokenStream,
}
impl fmt::Display for InputStructField {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name)
}
}
struct SalsaAttr {
name: String,
tts: TokenStream,
span: Span,
}
impl std::fmt::Debug for SalsaAttr {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(fmt, "{:?}", self.name)
}
}
impl TryFrom<syn::Attribute> for SalsaAttr {
type Error = syn::Attribute;
fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
if is_not_salsa_attr_path(attr.path()) {
return Err(attr);
}
let span = attr.span();
let name = attr.path().segments[1].ident.to_string();
let tts = match attr.meta {
syn::Meta::Path(path) => path.into_token_stream(),
syn::Meta::List(ref list) => {
let tts = list
.into_token_stream()
.into_iter()
.skip(attr.path().to_token_stream().into_iter().count());
proc_macro2::TokenStream::from_iter(tts)
}
syn::Meta::NameValue(nv) => nv.into_token_stream(),
}
.into();
Ok(SalsaAttr { name, tts, span })
}
}
fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2
}
fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
let mut other = vec![];
let mut salsa = vec![];
// Leave non-salsa attributes untouched. These are
// attributes that don't start with `salsa::` or don't have
// exactly two segments in their path.
for attr in attrs {
match SalsaAttr::try_from(attr) {
Ok(it) => salsa.push(it),
Err(it) => other.push(it),
}
}
(other, salsa)
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum QueryKind {
Input,
Tracked,
TrackedWithSalsaStruct,
Transparent,
Interned,
}
pub(crate) fn query_group_impl(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> Result<proc_macro::TokenStream, syn::Error> {
let mut item_trait = syn::parse::<ItemTrait>(input)?;
let supertraits = &item_trait.supertraits;
let db_attr: Attribute = parse_quote! {
#[salsa::db]
};
item_trait.attrs.push(db_attr);
let trait_name_ident = &item_trait.ident.clone();
let input_struct_name = format_ident!("{}Data", trait_name_ident);
let create_data_ident = format_ident!("create_data_{}", trait_name_ident);
let mut input_struct_fields: Vec<InputStructField> = vec![];
let mut trait_methods = vec![];
let mut setter_trait_methods = vec![];
let mut lookup_signatures = vec![];
let mut lookup_methods = vec![];
for item in item_trait.clone().items {
if let syn::TraitItem::Fn(method) = item {
let method_name = &method.sig.ident;
let signature = &method.sig.clone();
let (_attrs, salsa_attrs) = filter_attrs(method.attrs);
let mut query_kind = QueryKind::Tracked;
let mut invoke = None;
let mut cycle = None;
let mut interned_struct_path = None;
let mut lru = None;
let params: Vec<FnArg> = signature.inputs.clone().into_iter().collect();
let pat_and_tys = params
.into_iter()
.filter(|fn_arg| matches!(fn_arg, FnArg::Typed(_)))
.map(|fn_arg| match fn_arg {
FnArg::Typed(pat_type) => pat_type.clone(),
FnArg::Receiver(_) => unreachable!("this should have been filtered out"),
})
.collect::<Vec<syn::PatType>>();
for SalsaAttr { name, tts, span } in salsa_attrs {
match name.as_str() {
"cycle" => {
let path = syn::parse::<Parenthesized<Path>>(tts)?;
cycle = Some(path.0.clone())
}
"input" => {
if !pat_and_tys.is_empty() {
return Err(syn::Error::new(
span,
"input methods cannot have a parameter",
));
}
query_kind = QueryKind::Input;
}
"interned" => {
let syn::ReturnType::Type(_, ty) = &signature.output else {
return Err(syn::Error::new(
span,
"interned queries must have return type",
));
};
let syn::Type::Path(path) = &**ty else {
return Err(syn::Error::new(
span,
"interned queries must have return type",
));
};
interned_struct_path = Some(path.path.clone());
query_kind = QueryKind::Interned;
}
"invoke" => {
let path = syn::parse::<Parenthesized<Path>>(tts)?;
invoke = Some(path.0.clone());
}
"invoke_actual" => {
let path = syn::parse::<Parenthesized<Path>>(tts)?;
invoke = Some(path.0.clone());
query_kind = QueryKind::TrackedWithSalsaStruct;
}
"lru" => {
let lru_count = syn::parse::<Parenthesized<syn::LitInt>>(tts)?;
let lru_count = lru_count.0.base10_parse::<u32>()?;
lru = Some(lru_count);
}
"transparent" => {
query_kind = QueryKind::Transparent;
}
_ => return Err(syn::Error::new(span, format!("unknown attribute `{name}`"))),
}
}
let syn::ReturnType::Type(_, return_ty) = signature.output.clone() else {
return Err(syn::Error::new(signature.span(), "Queries must have a return type"));
};
if let syn::Type::Path(ref ty_path) = *return_ty {
if matches!(query_kind, QueryKind::Input) {
let field = InputStructField {
name: method_name.to_token_stream(),
ty: ty_path.path.to_token_stream(),
};
input_struct_fields.push(field);
}
}
match (query_kind, invoke) {
// input
(QueryKind::Input, None) => {
let query = InputQuery {
signature: method.sig.clone(),
create_data_ident: create_data_ident.clone(),
};
let value = Queries::InputQuery(query);
trait_methods.push(value);
let setter = InputSetter {
signature: method.sig.clone(),
return_type: *return_ty.clone(),
create_data_ident: create_data_ident.clone(),
};
setter_trait_methods.push(SetterKind::Plain(setter));
let setter = InputSetterWithDurability {
signature: method.sig.clone(),
return_type: *return_ty.clone(),
create_data_ident: create_data_ident.clone(),
};
setter_trait_methods.push(SetterKind::WithDurability(setter));
}
(QueryKind::Interned, None) => {
let interned_struct_path = interned_struct_path.unwrap();
let method = Intern {
signature: signature.clone(),
pat_and_tys: pat_and_tys.clone(),
interned_struct_path: interned_struct_path.clone(),
};
trait_methods.push(Queries::Intern(method));
let mut method = Lookup {
signature: signature.clone(),
pat_and_tys: pat_and_tys.clone(),
return_ty: *return_ty,
interned_struct_path,
};
method.prepare_signature();
lookup_signatures
.push(TraitItem::Fn(make_trait_method(method.signature.clone())));
lookup_methods.push(method);
}
// tracked function. it might have an invoke, or might not.
(QueryKind::Tracked, invoke) => {
let method = TrackedQuery {
trait_name: trait_name_ident.clone(),
generated_struct: Some(GeneratedInputStruct {
input_struct_name: input_struct_name.clone(),
create_data_ident: create_data_ident.clone(),
}),
signature: signature.clone(),
pat_and_tys: pat_and_tys.clone(),
invoke,
cycle,
lru,
};
trait_methods.push(Queries::TrackedQuery(method));
}
(QueryKind::TrackedWithSalsaStruct, Some(invoke)) => {
let method = TrackedQuery {
trait_name: trait_name_ident.clone(),
generated_struct: None,
signature: signature.clone(),
pat_and_tys: pat_and_tys.clone(),
invoke: Some(invoke),
cycle,
lru,
};
trait_methods.push(Queries::TrackedQuery(method))
}
// while it is possible to make this reachable, it's not really worthwhile for a migration aid.
// doing this would require attaching an attribute to the salsa struct parameter in the query.
(QueryKind::TrackedWithSalsaStruct, None) => unreachable!(),
(QueryKind::Transparent, invoke) => {
let method = Transparent {
signature: method.sig.clone(),
pat_and_tys: pat_and_tys.clone(),
invoke,
};
trait_methods.push(Queries::Transparent(method));
}
// error/invalid constructions
(QueryKind::Interned, Some(path)) => {
return Err(syn::Error::new(
path.span(),
"Interned queries cannot be used with an `#[invoke]`".to_string(),
))
}
(QueryKind::Input, Some(path)) => {
return Err(syn::Error::new(
path.span(),
"Inputs cannot be used with an `#[invoke]`".to_string(),
))
}
}
}
}
let fields = input_struct_fields
.into_iter()
.map(|input| {
let name = input.name;
let ret = input.ty;
quote! { #name: Option<#ret> }
})
.collect::<Vec<proc_macro2::TokenStream>>();
let input_struct = quote! {
#[salsa::input]
pub(crate) struct #input_struct_name {
#(#fields),*
}
};
let field_params = std::iter::repeat_n(quote! { None }, fields.len())
.collect::<Vec<proc_macro2::TokenStream>>();
let create_data_method = quote! {
#[allow(non_snake_case)]
#[salsa::tracked]
fn #create_data_ident(db: &dyn #trait_name_ident) -> #input_struct_name {
#input_struct_name::new(db, #(#field_params),*)
}
};
let mut setter_signatures = vec![];
let mut setter_methods = vec![];
for trait_item in setter_trait_methods
.iter()
.map(|method| method.to_token_stream())
.map(|tokens| syn::parse2::<syn::TraitItemFn>(tokens).unwrap())
{
let mut methods_sans_body = trait_item.clone();
methods_sans_body.default = None;
methods_sans_body.semi_token = Some(syn::Token![;](trait_item.span()));
setter_signatures.push(TraitItem::Fn(methods_sans_body));
setter_methods.push(TraitItem::Fn(trait_item));
}
item_trait.items.append(&mut setter_signatures);
item_trait.items.append(&mut lookup_signatures);
let trait_impl = quote! {
#[salsa::db]
impl<DB> #trait_name_ident for DB
where
DB: #supertraits,
{
#(#trait_methods)*
#(#setter_methods)*
#(#lookup_methods)*
}
};
RemoveAttrsFromTraitMethods.visit_item_trait_mut(&mut item_trait);
let out = quote! {
#item_trait
#trait_impl
#input_struct
#create_data_method
}
.into();
Ok(out)
}
/// Parenthesis helper
pub(crate) struct Parenthesized<T>(pub(crate) T);
impl<T> syn::parse::Parse for Parenthesized<T>
where
T: syn::parse::Parse,
{
fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
let content;
syn::parenthesized!(content in input);
content.parse::<T>().map(Parenthesized)
}
}
fn make_trait_method(sig: syn::Signature) -> TraitItemFn {
TraitItemFn {
attrs: vec![],
sig: sig.clone(),
semi_token: Some(syn::Token![;](sig.span())),
default: None,
}
}
struct RemoveAttrsFromTraitMethods;
impl VisitMut for RemoveAttrsFromTraitMethods {
fn visit_item_trait_mut(&mut self, i: &mut syn::ItemTrait) {
for item in &mut i.items {
if let TraitItem::Fn(trait_item_fn) = item {
trait_item_fn.attrs = vec![];
}
}
}
}
pub(crate) fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
tokens.extend(TokenStream::from(error.into_compile_error()));
tokens
}

View file

@ -0,0 +1,329 @@
//! The IR of the `#[query_group]` macro.
use quote::{format_ident, quote, ToTokens};
use syn::{parse_quote, FnArg, Ident, PatType, Path, Receiver, ReturnType, Type};
pub(crate) struct TrackedQuery {
pub(crate) trait_name: Ident,
pub(crate) signature: syn::Signature,
pub(crate) pat_and_tys: Vec<PatType>,
pub(crate) invoke: Option<Path>,
pub(crate) cycle: Option<Path>,
pub(crate) lru: Option<u32>,
pub(crate) generated_struct: Option<GeneratedInputStruct>,
}
pub(crate) struct GeneratedInputStruct {
pub(crate) input_struct_name: Ident,
pub(crate) create_data_ident: Ident,
}
impl ToTokens for TrackedQuery {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let sig = &self.signature;
let trait_name = &self.trait_name;
let ret = &sig.output;
let invoke = match &self.invoke {
Some(path) => path.to_token_stream(),
None => sig.ident.to_token_stream(),
};
let fn_ident = &sig.ident;
let shim: Ident = format_ident!("{}_shim", fn_ident);
let annotation = match (self.cycle.clone(), self.lru) {
(Some(cycle), Some(lru)) => quote!(#[salsa::tracked(lru = #lru, recovery_fn = #cycle)]),
(Some(cycle), None) => quote!(#[salsa::tracked(recovery_fn = #cycle)]),
(None, Some(lru)) => quote!(#[salsa::tracked(lru = #lru)]),
(None, None) => quote!(#[salsa::tracked]),
};
let pat_and_tys = &self.pat_and_tys;
let params = self
.pat_and_tys
.iter()
.map(|pat_type| pat_type.pat.clone())
.collect::<Vec<Box<syn::Pat>>>();
let method = match &self.generated_struct {
Some(generated_struct) => {
let input_struct_name = &generated_struct.input_struct_name;
let create_data_ident = &generated_struct.create_data_ident;
quote! {
#sig {
#annotation
fn #shim(
db: &dyn #trait_name,
_input: #input_struct_name,
#(#pat_and_tys),*
) #ret {
#invoke(db, #(#params),*)
}
#shim(self, #create_data_ident(self), #(#params),*)
}
}
}
None => {
quote! {
#sig {
#annotation
fn #shim(
db: &dyn #trait_name,
#(#pat_and_tys),*
) #ret {
#invoke(db, #(#params),*)
}
#shim(self, #(#params),*)
}
}
}
};
method.to_tokens(tokens);
}
}
pub(crate) struct InputQuery {
pub(crate) signature: syn::Signature,
pub(crate) create_data_ident: Ident,
}
impl ToTokens for InputQuery {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let sig = &self.signature;
let fn_ident = &sig.ident;
let create_data_ident = &self.create_data_ident;
let method = quote! {
#sig {
let data = #create_data_ident(self);
data.#fn_ident(self).unwrap()
}
};
method.to_tokens(tokens);
}
}
pub(crate) struct InputSetter {
pub(crate) signature: syn::Signature,
pub(crate) return_type: syn::Type,
pub(crate) create_data_ident: Ident,
}
impl ToTokens for InputSetter {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let sig = &mut self.signature.clone();
let ty = &self.return_type;
let fn_ident = &sig.ident;
let create_data_ident = &self.create_data_ident;
let setter_ident = format_ident!("set_{}", fn_ident);
sig.ident = setter_ident.clone();
let value_argument: PatType = parse_quote!(__value: #ty);
sig.inputs.push(FnArg::Typed(value_argument.clone()));
// make `&self` `&mut self` instead.
let mut_receiver: Receiver = parse_quote!(&mut self);
if let Some(og) = sig.inputs.first_mut() {
*og = FnArg::Receiver(mut_receiver)
}
// remove the return value.
sig.output = ReturnType::Default;
let value = &value_argument.pat;
let method = quote! {
#sig {
use salsa::Setter;
let data = #create_data_ident(self);
data.#setter_ident(self).to(Some(#value));
}
};
method.to_tokens(tokens);
}
}
pub(crate) struct InputSetterWithDurability {
pub(crate) signature: syn::Signature,
pub(crate) return_type: syn::Type,
pub(crate) create_data_ident: Ident,
}
impl ToTokens for InputSetterWithDurability {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let sig = &mut self.signature.clone();
let ty = &self.return_type;
let fn_ident = &sig.ident;
let setter_ident = format_ident!("set_{}", fn_ident);
let create_data_ident = &self.create_data_ident;
sig.ident = format_ident!("set_{}_with_durability", fn_ident);
let value_argument: PatType = parse_quote!(__value: #ty);
sig.inputs.push(FnArg::Typed(value_argument.clone()));
let durability_argument: PatType = parse_quote!(durability: salsa::Durability);
sig.inputs.push(FnArg::Typed(durability_argument.clone()));
// make `&self` `&mut self` instead.
let mut_receiver: Receiver = parse_quote!(&mut self);
if let Some(og) = sig.inputs.first_mut() {
*og = FnArg::Receiver(mut_receiver)
}
// remove the return value.
sig.output = ReturnType::Default;
let value = &value_argument.pat;
let durability = &durability_argument.pat;
let method = quote! {
#sig {
use salsa::Setter;
let data = #create_data_ident(self);
data.#setter_ident(self)
.with_durability(#durability)
.to(Some(#value));
}
};
method.to_tokens(tokens);
}
}
pub(crate) enum SetterKind {
Plain(InputSetter),
WithDurability(InputSetterWithDurability),
}
impl ToTokens for SetterKind {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
match self {
SetterKind::Plain(input_setter) => input_setter.to_tokens(tokens),
SetterKind::WithDurability(input_setter_with_durability) => {
input_setter_with_durability.to_tokens(tokens)
}
}
}
}
pub(crate) struct Transparent {
pub(crate) signature: syn::Signature,
pub(crate) pat_and_tys: Vec<PatType>,
pub(crate) invoke: Option<Path>,
}
impl ToTokens for Transparent {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let sig = &self.signature;
let ty = self
.pat_and_tys
.iter()
.map(|pat_type| pat_type.pat.clone())
.collect::<Vec<Box<syn::Pat>>>();
let invoke = match &self.invoke {
Some(path) => path.to_token_stream(),
None => sig.ident.to_token_stream(),
};
let method = quote! {
#sig {
#invoke(self, #(#ty),*)
}
};
method.to_tokens(tokens);
}
}
pub(crate) struct Intern {
pub(crate) signature: syn::Signature,
pub(crate) pat_and_tys: Vec<PatType>,
pub(crate) interned_struct_path: Path,
}
impl ToTokens for Intern {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let sig = &self.signature;
let ty = self.pat_and_tys.to_vec();
let interned_pat = ty.first().expect("at least one pat; this is a bug");
let interned_pat = &interned_pat.pat;
let wrapper_struct = self.interned_struct_path.to_token_stream();
let method = quote! {
#sig {
#wrapper_struct::new(self, #interned_pat)
}
};
method.to_tokens(tokens);
}
}
pub(crate) struct Lookup {
pub(crate) signature: syn::Signature,
pub(crate) pat_and_tys: Vec<PatType>,
pub(crate) return_ty: Type,
pub(crate) interned_struct_path: Path,
}
impl Lookup {
pub(crate) fn prepare_signature(&mut self) {
let sig = &self.signature;
let ident = format_ident!("lookup_{}", sig.ident);
let ty = self.pat_and_tys.to_vec();
let interned_key = &self.return_ty;
let interned_pat = ty.first().expect("at least one pat; this is a bug");
let interned_return_ty = &interned_pat.ty;
self.signature = parse_quote!(
fn #ident(&self, id: #interned_key) -> #interned_return_ty
);
}
}
impl ToTokens for Lookup {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let sig = &self.signature;
let wrapper_struct = self.interned_struct_path.to_token_stream();
let method = quote! {
#sig {
#wrapper_struct::ingredient(self).data(self.as_dyn_database(), id.as_id()).0.clone()
}
};
method.to_tokens(tokens);
}
}
pub(crate) enum Queries {
TrackedQuery(TrackedQuery),
InputQuery(InputQuery),
Intern(Intern),
Transparent(Transparent),
}
impl ToTokens for Queries {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
match self {
Queries::TrackedQuery(tracked_query) => tracked_query.to_tokens(tokens),
Queries::InputQuery(input_query) => input_query.to_tokens(tokens),
Queries::Transparent(transparent) => transparent.to_tokens(tokens),
Queries::Intern(intern) => intern.to_tokens(tokens),
}
}
}