mirror of
https://github.com/rust-lang/rust-analyzer.git
synced 2025-08-04 18:58:41 +00:00
internal: vendor query-group-macro
This commit is contained in:
parent
bd0289e0e9
commit
7a7ff470ca
17 changed files with 1989 additions and 22 deletions
23
crates/query-group-macro/Cargo.toml
Normal file
23
crates/query-group-macro/Cargo.toml
Normal file
|
@ -0,0 +1,23 @@
|
|||
[package]
|
||||
name = "query-group-macro"
|
||||
version = "0.0.0"
|
||||
repository.workspace = true
|
||||
description = "A macro mimicking the `#[salsa::query_group]` macro for migrating to new Salsa"
|
||||
|
||||
authors.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
heck = "0.5.0"
|
||||
proc-macro2 = "1.0"
|
||||
quote = "1.0"
|
||||
syn = { version = "2.0", features = ["full", "extra-traits"] }
|
||||
salsa = { version = "0.19.0" }
|
||||
|
||||
[dev-dependencies]
|
||||
expect-test = "1.5.0"
|
437
crates/query-group-macro/src/lib.rs
Normal file
437
crates/query-group-macro/src/lib.rs
Normal file
|
@ -0,0 +1,437 @@
|
|||
//! A macro that mimics the old Salsa-style `#[query_group]` macro.
|
||||
|
||||
use core::fmt;
|
||||
use std::vec;
|
||||
|
||||
use proc_macro::TokenStream;
|
||||
use proc_macro2::Span;
|
||||
use queries::{
|
||||
GeneratedInputStruct, InputQuery, InputSetter, InputSetterWithDurability, Intern, Lookup,
|
||||
Queries, SetterKind, TrackedQuery, Transparent,
|
||||
};
|
||||
use quote::{format_ident, quote, ToTokens};
|
||||
use syn::spanned::Spanned;
|
||||
use syn::visit_mut::VisitMut;
|
||||
use syn::{parse_quote, Attribute, FnArg, ItemTrait, Path, TraitItem, TraitItemFn};
|
||||
|
||||
mod queries;
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
match query_group_impl(args, input.clone()) {
|
||||
Ok(tokens) => tokens,
|
||||
Err(e) => token_stream_with_error(input, e),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InputStructField {
|
||||
name: proc_macro2::TokenStream,
|
||||
ty: proc_macro2::TokenStream,
|
||||
}
|
||||
|
||||
impl fmt::Display for InputStructField {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name)
|
||||
}
|
||||
}
|
||||
|
||||
struct SalsaAttr {
|
||||
name: String,
|
||||
tts: TokenStream,
|
||||
span: Span,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SalsaAttr {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(fmt, "{:?}", self.name)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<syn::Attribute> for SalsaAttr {
|
||||
type Error = syn::Attribute;
|
||||
|
||||
fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
|
||||
if is_not_salsa_attr_path(attr.path()) {
|
||||
return Err(attr);
|
||||
}
|
||||
|
||||
let span = attr.span();
|
||||
|
||||
let name = attr.path().segments[1].ident.to_string();
|
||||
let tts = match attr.meta {
|
||||
syn::Meta::Path(path) => path.into_token_stream(),
|
||||
syn::Meta::List(ref list) => {
|
||||
let tts = list
|
||||
.into_token_stream()
|
||||
.into_iter()
|
||||
.skip(attr.path().to_token_stream().into_iter().count());
|
||||
proc_macro2::TokenStream::from_iter(tts)
|
||||
}
|
||||
syn::Meta::NameValue(nv) => nv.into_token_stream(),
|
||||
}
|
||||
.into();
|
||||
|
||||
Ok(SalsaAttr { name, tts, span })
|
||||
}
|
||||
}
|
||||
|
||||
fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
|
||||
path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2
|
||||
}
|
||||
|
||||
fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
|
||||
let mut other = vec![];
|
||||
let mut salsa = vec![];
|
||||
// Leave non-salsa attributes untouched. These are
|
||||
// attributes that don't start with `salsa::` or don't have
|
||||
// exactly two segments in their path.
|
||||
for attr in attrs {
|
||||
match SalsaAttr::try_from(attr) {
|
||||
Ok(it) => salsa.push(it),
|
||||
Err(it) => other.push(it),
|
||||
}
|
||||
}
|
||||
(other, salsa)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
enum QueryKind {
|
||||
Input,
|
||||
Tracked,
|
||||
TrackedWithSalsaStruct,
|
||||
Transparent,
|
||||
Interned,
|
||||
}
|
||||
|
||||
pub(crate) fn query_group_impl(
|
||||
_args: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream,
|
||||
) -> Result<proc_macro::TokenStream, syn::Error> {
|
||||
let mut item_trait = syn::parse::<ItemTrait>(input)?;
|
||||
|
||||
let supertraits = &item_trait.supertraits;
|
||||
|
||||
let db_attr: Attribute = parse_quote! {
|
||||
#[salsa::db]
|
||||
};
|
||||
item_trait.attrs.push(db_attr);
|
||||
|
||||
let trait_name_ident = &item_trait.ident.clone();
|
||||
let input_struct_name = format_ident!("{}Data", trait_name_ident);
|
||||
let create_data_ident = format_ident!("create_data_{}", trait_name_ident);
|
||||
|
||||
let mut input_struct_fields: Vec<InputStructField> = vec![];
|
||||
let mut trait_methods = vec![];
|
||||
let mut setter_trait_methods = vec![];
|
||||
let mut lookup_signatures = vec![];
|
||||
let mut lookup_methods = vec![];
|
||||
|
||||
for item in item_trait.clone().items {
|
||||
if let syn::TraitItem::Fn(method) = item {
|
||||
let method_name = &method.sig.ident;
|
||||
let signature = &method.sig.clone();
|
||||
|
||||
let (_attrs, salsa_attrs) = filter_attrs(method.attrs);
|
||||
|
||||
let mut query_kind = QueryKind::Tracked;
|
||||
let mut invoke = None;
|
||||
let mut cycle = None;
|
||||
let mut interned_struct_path = None;
|
||||
let mut lru = None;
|
||||
|
||||
let params: Vec<FnArg> = signature.inputs.clone().into_iter().collect();
|
||||
let pat_and_tys = params
|
||||
.into_iter()
|
||||
.filter(|fn_arg| matches!(fn_arg, FnArg::Typed(_)))
|
||||
.map(|fn_arg| match fn_arg {
|
||||
FnArg::Typed(pat_type) => pat_type.clone(),
|
||||
FnArg::Receiver(_) => unreachable!("this should have been filtered out"),
|
||||
})
|
||||
.collect::<Vec<syn::PatType>>();
|
||||
|
||||
for SalsaAttr { name, tts, span } in salsa_attrs {
|
||||
match name.as_str() {
|
||||
"cycle" => {
|
||||
let path = syn::parse::<Parenthesized<Path>>(tts)?;
|
||||
cycle = Some(path.0.clone())
|
||||
}
|
||||
"input" => {
|
||||
if !pat_and_tys.is_empty() {
|
||||
return Err(syn::Error::new(
|
||||
span,
|
||||
"input methods cannot have a parameter",
|
||||
));
|
||||
}
|
||||
query_kind = QueryKind::Input;
|
||||
}
|
||||
"interned" => {
|
||||
let syn::ReturnType::Type(_, ty) = &signature.output else {
|
||||
return Err(syn::Error::new(
|
||||
span,
|
||||
"interned queries must have return type",
|
||||
));
|
||||
};
|
||||
let syn::Type::Path(path) = &**ty else {
|
||||
return Err(syn::Error::new(
|
||||
span,
|
||||
"interned queries must have return type",
|
||||
));
|
||||
};
|
||||
interned_struct_path = Some(path.path.clone());
|
||||
query_kind = QueryKind::Interned;
|
||||
}
|
||||
"invoke" => {
|
||||
let path = syn::parse::<Parenthesized<Path>>(tts)?;
|
||||
invoke = Some(path.0.clone());
|
||||
}
|
||||
"invoke_actual" => {
|
||||
let path = syn::parse::<Parenthesized<Path>>(tts)?;
|
||||
invoke = Some(path.0.clone());
|
||||
query_kind = QueryKind::TrackedWithSalsaStruct;
|
||||
}
|
||||
"lru" => {
|
||||
let lru_count = syn::parse::<Parenthesized<syn::LitInt>>(tts)?;
|
||||
let lru_count = lru_count.0.base10_parse::<u32>()?;
|
||||
|
||||
lru = Some(lru_count);
|
||||
}
|
||||
"transparent" => {
|
||||
query_kind = QueryKind::Transparent;
|
||||
}
|
||||
_ => return Err(syn::Error::new(span, format!("unknown attribute `{name}`"))),
|
||||
}
|
||||
}
|
||||
|
||||
let syn::ReturnType::Type(_, return_ty) = signature.output.clone() else {
|
||||
return Err(syn::Error::new(signature.span(), "Queries must have a return type"));
|
||||
};
|
||||
|
||||
if let syn::Type::Path(ref ty_path) = *return_ty {
|
||||
if matches!(query_kind, QueryKind::Input) {
|
||||
let field = InputStructField {
|
||||
name: method_name.to_token_stream(),
|
||||
ty: ty_path.path.to_token_stream(),
|
||||
};
|
||||
|
||||
input_struct_fields.push(field);
|
||||
}
|
||||
}
|
||||
|
||||
match (query_kind, invoke) {
|
||||
// input
|
||||
(QueryKind::Input, None) => {
|
||||
let query = InputQuery {
|
||||
signature: method.sig.clone(),
|
||||
create_data_ident: create_data_ident.clone(),
|
||||
};
|
||||
let value = Queries::InputQuery(query);
|
||||
trait_methods.push(value);
|
||||
|
||||
let setter = InputSetter {
|
||||
signature: method.sig.clone(),
|
||||
return_type: *return_ty.clone(),
|
||||
create_data_ident: create_data_ident.clone(),
|
||||
};
|
||||
setter_trait_methods.push(SetterKind::Plain(setter));
|
||||
|
||||
let setter = InputSetterWithDurability {
|
||||
signature: method.sig.clone(),
|
||||
return_type: *return_ty.clone(),
|
||||
create_data_ident: create_data_ident.clone(),
|
||||
};
|
||||
setter_trait_methods.push(SetterKind::WithDurability(setter));
|
||||
}
|
||||
(QueryKind::Interned, None) => {
|
||||
let interned_struct_path = interned_struct_path.unwrap();
|
||||
let method = Intern {
|
||||
signature: signature.clone(),
|
||||
pat_and_tys: pat_and_tys.clone(),
|
||||
interned_struct_path: interned_struct_path.clone(),
|
||||
};
|
||||
|
||||
trait_methods.push(Queries::Intern(method));
|
||||
|
||||
let mut method = Lookup {
|
||||
signature: signature.clone(),
|
||||
pat_and_tys: pat_and_tys.clone(),
|
||||
return_ty: *return_ty,
|
||||
interned_struct_path,
|
||||
};
|
||||
method.prepare_signature();
|
||||
|
||||
lookup_signatures
|
||||
.push(TraitItem::Fn(make_trait_method(method.signature.clone())));
|
||||
lookup_methods.push(method);
|
||||
}
|
||||
// tracked function. it might have an invoke, or might not.
|
||||
(QueryKind::Tracked, invoke) => {
|
||||
let method = TrackedQuery {
|
||||
trait_name: trait_name_ident.clone(),
|
||||
generated_struct: Some(GeneratedInputStruct {
|
||||
input_struct_name: input_struct_name.clone(),
|
||||
create_data_ident: create_data_ident.clone(),
|
||||
}),
|
||||
signature: signature.clone(),
|
||||
pat_and_tys: pat_and_tys.clone(),
|
||||
invoke,
|
||||
cycle,
|
||||
lru,
|
||||
};
|
||||
|
||||
trait_methods.push(Queries::TrackedQuery(method));
|
||||
}
|
||||
(QueryKind::TrackedWithSalsaStruct, Some(invoke)) => {
|
||||
let method = TrackedQuery {
|
||||
trait_name: trait_name_ident.clone(),
|
||||
generated_struct: None,
|
||||
signature: signature.clone(),
|
||||
pat_and_tys: pat_and_tys.clone(),
|
||||
invoke: Some(invoke),
|
||||
cycle,
|
||||
lru,
|
||||
};
|
||||
|
||||
trait_methods.push(Queries::TrackedQuery(method))
|
||||
}
|
||||
// while it is possible to make this reachable, it's not really worthwhile for a migration aid.
|
||||
// doing this would require attaching an attribute to the salsa struct parameter in the query.
|
||||
(QueryKind::TrackedWithSalsaStruct, None) => unreachable!(),
|
||||
(QueryKind::Transparent, invoke) => {
|
||||
let method = Transparent {
|
||||
signature: method.sig.clone(),
|
||||
pat_and_tys: pat_and_tys.clone(),
|
||||
invoke,
|
||||
};
|
||||
trait_methods.push(Queries::Transparent(method));
|
||||
}
|
||||
// error/invalid constructions
|
||||
(QueryKind::Interned, Some(path)) => {
|
||||
return Err(syn::Error::new(
|
||||
path.span(),
|
||||
"Interned queries cannot be used with an `#[invoke]`".to_string(),
|
||||
))
|
||||
}
|
||||
(QueryKind::Input, Some(path)) => {
|
||||
return Err(syn::Error::new(
|
||||
path.span(),
|
||||
"Inputs cannot be used with an `#[invoke]`".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let fields = input_struct_fields
|
||||
.into_iter()
|
||||
.map(|input| {
|
||||
let name = input.name;
|
||||
let ret = input.ty;
|
||||
quote! { #name: Option<#ret> }
|
||||
})
|
||||
.collect::<Vec<proc_macro2::TokenStream>>();
|
||||
|
||||
let input_struct = quote! {
|
||||
#[salsa::input]
|
||||
pub(crate) struct #input_struct_name {
|
||||
#(#fields),*
|
||||
}
|
||||
};
|
||||
|
||||
let field_params = std::iter::repeat_n(quote! { None }, fields.len())
|
||||
.collect::<Vec<proc_macro2::TokenStream>>();
|
||||
|
||||
let create_data_method = quote! {
|
||||
#[allow(non_snake_case)]
|
||||
#[salsa::tracked]
|
||||
fn #create_data_ident(db: &dyn #trait_name_ident) -> #input_struct_name {
|
||||
#input_struct_name::new(db, #(#field_params),*)
|
||||
}
|
||||
};
|
||||
|
||||
let mut setter_signatures = vec![];
|
||||
let mut setter_methods = vec![];
|
||||
for trait_item in setter_trait_methods
|
||||
.iter()
|
||||
.map(|method| method.to_token_stream())
|
||||
.map(|tokens| syn::parse2::<syn::TraitItemFn>(tokens).unwrap())
|
||||
{
|
||||
let mut methods_sans_body = trait_item.clone();
|
||||
methods_sans_body.default = None;
|
||||
methods_sans_body.semi_token = Some(syn::Token));
|
||||
|
||||
setter_signatures.push(TraitItem::Fn(methods_sans_body));
|
||||
setter_methods.push(TraitItem::Fn(trait_item));
|
||||
}
|
||||
|
||||
item_trait.items.append(&mut setter_signatures);
|
||||
item_trait.items.append(&mut lookup_signatures);
|
||||
|
||||
let trait_impl = quote! {
|
||||
#[salsa::db]
|
||||
impl<DB> #trait_name_ident for DB
|
||||
where
|
||||
DB: #supertraits,
|
||||
{
|
||||
#(#trait_methods)*
|
||||
|
||||
#(#setter_methods)*
|
||||
|
||||
#(#lookup_methods)*
|
||||
}
|
||||
};
|
||||
RemoveAttrsFromTraitMethods.visit_item_trait_mut(&mut item_trait);
|
||||
|
||||
let out = quote! {
|
||||
#item_trait
|
||||
|
||||
#trait_impl
|
||||
|
||||
#input_struct
|
||||
|
||||
#create_data_method
|
||||
}
|
||||
.into();
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Parenthesis helper
|
||||
pub(crate) struct Parenthesized<T>(pub(crate) T);
|
||||
|
||||
impl<T> syn::parse::Parse for Parenthesized<T>
|
||||
where
|
||||
T: syn::parse::Parse,
|
||||
{
|
||||
fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
|
||||
let content;
|
||||
syn::parenthesized!(content in input);
|
||||
content.parse::<T>().map(Parenthesized)
|
||||
}
|
||||
}
|
||||
|
||||
fn make_trait_method(sig: syn::Signature) -> TraitItemFn {
|
||||
TraitItemFn {
|
||||
attrs: vec![],
|
||||
sig: sig.clone(),
|
||||
semi_token: Some(syn::Token)),
|
||||
default: None,
|
||||
}
|
||||
}
|
||||
|
||||
struct RemoveAttrsFromTraitMethods;
|
||||
|
||||
impl VisitMut for RemoveAttrsFromTraitMethods {
|
||||
fn visit_item_trait_mut(&mut self, i: &mut syn::ItemTrait) {
|
||||
for item in &mut i.items {
|
||||
if let TraitItem::Fn(trait_item_fn) = item {
|
||||
trait_item_fn.attrs = vec![];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
|
||||
tokens.extend(TokenStream::from(error.into_compile_error()));
|
||||
tokens
|
||||
}
|
329
crates/query-group-macro/src/queries.rs
Normal file
329
crates/query-group-macro/src/queries.rs
Normal file
|
@ -0,0 +1,329 @@
|
|||
//! The IR of the `#[query_group]` macro.
|
||||
|
||||
use quote::{format_ident, quote, ToTokens};
|
||||
use syn::{parse_quote, FnArg, Ident, PatType, Path, Receiver, ReturnType, Type};
|
||||
|
||||
pub(crate) struct TrackedQuery {
|
||||
pub(crate) trait_name: Ident,
|
||||
pub(crate) signature: syn::Signature,
|
||||
pub(crate) pat_and_tys: Vec<PatType>,
|
||||
pub(crate) invoke: Option<Path>,
|
||||
pub(crate) cycle: Option<Path>,
|
||||
pub(crate) lru: Option<u32>,
|
||||
pub(crate) generated_struct: Option<GeneratedInputStruct>,
|
||||
}
|
||||
|
||||
pub(crate) struct GeneratedInputStruct {
|
||||
pub(crate) input_struct_name: Ident,
|
||||
pub(crate) create_data_ident: Ident,
|
||||
}
|
||||
|
||||
impl ToTokens for TrackedQuery {
|
||||
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
|
||||
let sig = &self.signature;
|
||||
let trait_name = &self.trait_name;
|
||||
|
||||
let ret = &sig.output;
|
||||
|
||||
let invoke = match &self.invoke {
|
||||
Some(path) => path.to_token_stream(),
|
||||
None => sig.ident.to_token_stream(),
|
||||
};
|
||||
|
||||
let fn_ident = &sig.ident;
|
||||
let shim: Ident = format_ident!("{}_shim", fn_ident);
|
||||
|
||||
let annotation = match (self.cycle.clone(), self.lru) {
|
||||
(Some(cycle), Some(lru)) => quote!(#[salsa::tracked(lru = #lru, recovery_fn = #cycle)]),
|
||||
(Some(cycle), None) => quote!(#[salsa::tracked(recovery_fn = #cycle)]),
|
||||
(None, Some(lru)) => quote!(#[salsa::tracked(lru = #lru)]),
|
||||
(None, None) => quote!(#[salsa::tracked]),
|
||||
};
|
||||
|
||||
let pat_and_tys = &self.pat_and_tys;
|
||||
let params = self
|
||||
.pat_and_tys
|
||||
.iter()
|
||||
.map(|pat_type| pat_type.pat.clone())
|
||||
.collect::<Vec<Box<syn::Pat>>>();
|
||||
|
||||
let method = match &self.generated_struct {
|
||||
Some(generated_struct) => {
|
||||
let input_struct_name = &generated_struct.input_struct_name;
|
||||
let create_data_ident = &generated_struct.create_data_ident;
|
||||
|
||||
quote! {
|
||||
#sig {
|
||||
#annotation
|
||||
fn #shim(
|
||||
db: &dyn #trait_name,
|
||||
_input: #input_struct_name,
|
||||
#(#pat_and_tys),*
|
||||
) #ret {
|
||||
#invoke(db, #(#params),*)
|
||||
}
|
||||
#shim(self, #create_data_ident(self), #(#params),*)
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
quote! {
|
||||
#sig {
|
||||
#annotation
|
||||
fn #shim(
|
||||
db: &dyn #trait_name,
|
||||
#(#pat_and_tys),*
|
||||
) #ret {
|
||||
#invoke(db, #(#params),*)
|
||||
}
|
||||
#shim(self, #(#params),*)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
method.to_tokens(tokens);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct InputQuery {
|
||||
pub(crate) signature: syn::Signature,
|
||||
pub(crate) create_data_ident: Ident,
|
||||
}
|
||||
|
||||
impl ToTokens for InputQuery {
|
||||
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
|
||||
let sig = &self.signature;
|
||||
let fn_ident = &sig.ident;
|
||||
let create_data_ident = &self.create_data_ident;
|
||||
|
||||
let method = quote! {
|
||||
#sig {
|
||||
let data = #create_data_ident(self);
|
||||
data.#fn_ident(self).unwrap()
|
||||
}
|
||||
};
|
||||
method.to_tokens(tokens);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct InputSetter {
|
||||
pub(crate) signature: syn::Signature,
|
||||
pub(crate) return_type: syn::Type,
|
||||
pub(crate) create_data_ident: Ident,
|
||||
}
|
||||
|
||||
impl ToTokens for InputSetter {
|
||||
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
|
||||
let sig = &mut self.signature.clone();
|
||||
|
||||
let ty = &self.return_type;
|
||||
let fn_ident = &sig.ident;
|
||||
let create_data_ident = &self.create_data_ident;
|
||||
|
||||
let setter_ident = format_ident!("set_{}", fn_ident);
|
||||
sig.ident = setter_ident.clone();
|
||||
|
||||
let value_argument: PatType = parse_quote!(__value: #ty);
|
||||
sig.inputs.push(FnArg::Typed(value_argument.clone()));
|
||||
|
||||
// make `&self` `&mut self` instead.
|
||||
let mut_receiver: Receiver = parse_quote!(&mut self);
|
||||
if let Some(og) = sig.inputs.first_mut() {
|
||||
*og = FnArg::Receiver(mut_receiver)
|
||||
}
|
||||
|
||||
// remove the return value.
|
||||
sig.output = ReturnType::Default;
|
||||
|
||||
let value = &value_argument.pat;
|
||||
let method = quote! {
|
||||
#sig {
|
||||
use salsa::Setter;
|
||||
let data = #create_data_ident(self);
|
||||
data.#setter_ident(self).to(Some(#value));
|
||||
}
|
||||
};
|
||||
method.to_tokens(tokens);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct InputSetterWithDurability {
|
||||
pub(crate) signature: syn::Signature,
|
||||
pub(crate) return_type: syn::Type,
|
||||
pub(crate) create_data_ident: Ident,
|
||||
}
|
||||
|
||||
impl ToTokens for InputSetterWithDurability {
|
||||
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
|
||||
let sig = &mut self.signature.clone();
|
||||
|
||||
let ty = &self.return_type;
|
||||
let fn_ident = &sig.ident;
|
||||
let setter_ident = format_ident!("set_{}", fn_ident);
|
||||
|
||||
let create_data_ident = &self.create_data_ident;
|
||||
|
||||
sig.ident = format_ident!("set_{}_with_durability", fn_ident);
|
||||
|
||||
let value_argument: PatType = parse_quote!(__value: #ty);
|
||||
sig.inputs.push(FnArg::Typed(value_argument.clone()));
|
||||
|
||||
let durability_argument: PatType = parse_quote!(durability: salsa::Durability);
|
||||
sig.inputs.push(FnArg::Typed(durability_argument.clone()));
|
||||
|
||||
// make `&self` `&mut self` instead.
|
||||
let mut_receiver: Receiver = parse_quote!(&mut self);
|
||||
if let Some(og) = sig.inputs.first_mut() {
|
||||
*og = FnArg::Receiver(mut_receiver)
|
||||
}
|
||||
|
||||
// remove the return value.
|
||||
sig.output = ReturnType::Default;
|
||||
|
||||
let value = &value_argument.pat;
|
||||
let durability = &durability_argument.pat;
|
||||
let method = quote! {
|
||||
#sig {
|
||||
use salsa::Setter;
|
||||
let data = #create_data_ident(self);
|
||||
data.#setter_ident(self)
|
||||
.with_durability(#durability)
|
||||
.to(Some(#value));
|
||||
}
|
||||
};
|
||||
method.to_tokens(tokens);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum SetterKind {
|
||||
Plain(InputSetter),
|
||||
WithDurability(InputSetterWithDurability),
|
||||
}
|
||||
|
||||
impl ToTokens for SetterKind {
|
||||
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
|
||||
match self {
|
||||
SetterKind::Plain(input_setter) => input_setter.to_tokens(tokens),
|
||||
SetterKind::WithDurability(input_setter_with_durability) => {
|
||||
input_setter_with_durability.to_tokens(tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct Transparent {
|
||||
pub(crate) signature: syn::Signature,
|
||||
pub(crate) pat_and_tys: Vec<PatType>,
|
||||
pub(crate) invoke: Option<Path>,
|
||||
}
|
||||
|
||||
impl ToTokens for Transparent {
|
||||
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
|
||||
let sig = &self.signature;
|
||||
|
||||
let ty = self
|
||||
.pat_and_tys
|
||||
.iter()
|
||||
.map(|pat_type| pat_type.pat.clone())
|
||||
.collect::<Vec<Box<syn::Pat>>>();
|
||||
|
||||
let invoke = match &self.invoke {
|
||||
Some(path) => path.to_token_stream(),
|
||||
None => sig.ident.to_token_stream(),
|
||||
};
|
||||
|
||||
let method = quote! {
|
||||
#sig {
|
||||
#invoke(self, #(#ty),*)
|
||||
}
|
||||
};
|
||||
|
||||
method.to_tokens(tokens);
|
||||
}
|
||||
}
|
||||
pub(crate) struct Intern {
|
||||
pub(crate) signature: syn::Signature,
|
||||
pub(crate) pat_and_tys: Vec<PatType>,
|
||||
pub(crate) interned_struct_path: Path,
|
||||
}
|
||||
|
||||
impl ToTokens for Intern {
|
||||
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
|
||||
let sig = &self.signature;
|
||||
|
||||
let ty = self.pat_and_tys.to_vec();
|
||||
|
||||
let interned_pat = ty.first().expect("at least one pat; this is a bug");
|
||||
let interned_pat = &interned_pat.pat;
|
||||
|
||||
let wrapper_struct = self.interned_struct_path.to_token_stream();
|
||||
|
||||
let method = quote! {
|
||||
#sig {
|
||||
#wrapper_struct::new(self, #interned_pat)
|
||||
}
|
||||
};
|
||||
|
||||
method.to_tokens(tokens);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct Lookup {
|
||||
pub(crate) signature: syn::Signature,
|
||||
pub(crate) pat_and_tys: Vec<PatType>,
|
||||
pub(crate) return_ty: Type,
|
||||
pub(crate) interned_struct_path: Path,
|
||||
}
|
||||
|
||||
impl Lookup {
|
||||
pub(crate) fn prepare_signature(&mut self) {
|
||||
let sig = &self.signature;
|
||||
|
||||
let ident = format_ident!("lookup_{}", sig.ident);
|
||||
|
||||
let ty = self.pat_and_tys.to_vec();
|
||||
|
||||
let interned_key = &self.return_ty;
|
||||
|
||||
let interned_pat = ty.first().expect("at least one pat; this is a bug");
|
||||
let interned_return_ty = &interned_pat.ty;
|
||||
|
||||
self.signature = parse_quote!(
|
||||
fn #ident(&self, id: #interned_key) -> #interned_return_ty
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl ToTokens for Lookup {
|
||||
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
|
||||
let sig = &self.signature;
|
||||
|
||||
let wrapper_struct = self.interned_struct_path.to_token_stream();
|
||||
let method = quote! {
|
||||
#sig {
|
||||
#wrapper_struct::ingredient(self).data(self.as_dyn_database(), id.as_id()).0.clone()
|
||||
}
|
||||
};
|
||||
|
||||
method.to_tokens(tokens);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum Queries {
|
||||
TrackedQuery(TrackedQuery),
|
||||
InputQuery(InputQuery),
|
||||
Intern(Intern),
|
||||
Transparent(Transparent),
|
||||
}
|
||||
|
||||
impl ToTokens for Queries {
|
||||
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
|
||||
match self {
|
||||
Queries::TrackedQuery(tracked_query) => tracked_query.to_tokens(tokens),
|
||||
Queries::InputQuery(input_query) => input_query.to_tokens(tokens),
|
||||
Queries::Transparent(transparent) => transparent.to_tokens(tokens),
|
||||
Queries::Intern(intern) => intern.to_tokens(tokens),
|
||||
}
|
||||
}
|
||||
}
|
28
crates/query-group-macro/tests/arity.rs
Normal file
28
crates/query-group-macro/tests/arity.rs
Normal file
|
@ -0,0 +1,28 @@
|
|||
use query_group_macro::query_group;
|
||||
|
||||
#[query_group]
|
||||
pub trait ArityDb: salsa::Database {
|
||||
fn one(&self, a: ()) -> String;
|
||||
|
||||
fn two(&self, a: (), b: ()) -> String;
|
||||
|
||||
fn three(&self, a: (), b: (), c: ()) -> String;
|
||||
|
||||
fn none(&self) -> String;
|
||||
}
|
||||
|
||||
fn one(_db: &dyn ArityDb, _a: ()) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
fn two(_db: &dyn ArityDb, _a: (), _b: ()) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
fn three(_db: &dyn ArityDb, _a: (), _b: (), _c: ()) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
fn none(_db: &dyn ArityDb) -> String {
|
||||
String::new()
|
||||
}
|
275
crates/query-group-macro/tests/cycle.rs
Normal file
275
crates/query-group-macro/tests/cycle.rs
Normal file
|
@ -0,0 +1,275 @@
|
|||
use std::panic::UnwindSafe;
|
||||
|
||||
use expect_test::expect;
|
||||
use query_group_macro::query_group;
|
||||
use salsa::Setter;
|
||||
|
||||
/// The queries A, B, and C in `Database` can be configured
|
||||
/// to invoke one another in arbitrary ways using this
|
||||
/// enum.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
enum CycleQuery {
|
||||
None,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
AthenC,
|
||||
}
|
||||
|
||||
#[salsa::input]
|
||||
struct ABC {
|
||||
a: CycleQuery,
|
||||
b: CycleQuery,
|
||||
c: CycleQuery,
|
||||
}
|
||||
|
||||
impl CycleQuery {
|
||||
fn invoke(self, db: &dyn CycleDatabase, abc: ABC) -> Result<(), Error> {
|
||||
match self {
|
||||
CycleQuery::A => db.cycle_a(abc),
|
||||
CycleQuery::B => db.cycle_b(abc),
|
||||
CycleQuery::C => db.cycle_c(abc),
|
||||
CycleQuery::AthenC => {
|
||||
let _ = db.cycle_a(abc);
|
||||
db.cycle_c(abc)
|
||||
}
|
||||
CycleQuery::None => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[salsa::input]
|
||||
struct MyInput {}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn memoized_a(db: &dyn CycleDatabase, input: MyInput) {
|
||||
memoized_b(db, input)
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn memoized_b(db: &dyn CycleDatabase, input: MyInput) {
|
||||
memoized_a(db, input)
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn volatile_a(db: &dyn CycleDatabase, input: MyInput) {
|
||||
db.report_untracked_read();
|
||||
volatile_b(db, input)
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn volatile_b(db: &dyn CycleDatabase, input: MyInput) {
|
||||
db.report_untracked_read();
|
||||
volatile_a(db, input)
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle {
|
||||
let v = std::panic::catch_unwind(f);
|
||||
if let Err(d) = &v {
|
||||
if let Some(cycle) = d.downcast_ref::<salsa::Cycle>() {
|
||||
return cycle.clone();
|
||||
}
|
||||
}
|
||||
panic!("unexpected value: {:?}", v)
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
|
||||
struct Error {
|
||||
cycle: Vec<String>,
|
||||
}
|
||||
|
||||
#[query_group]
|
||||
trait CycleDatabase: salsa::Database {
|
||||
#[salsa::cycle(recover_a)]
|
||||
fn cycle_a(&self, abc: ABC) -> Result<(), Error>;
|
||||
|
||||
#[salsa::cycle(recover_b)]
|
||||
fn cycle_b(&self, abc: ABC) -> Result<(), Error>;
|
||||
|
||||
fn cycle_c(&self, abc: ABC) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
fn cycle_a(db: &dyn CycleDatabase, abc: ABC) -> Result<(), Error> {
|
||||
abc.a(db).invoke(db, abc)
|
||||
}
|
||||
|
||||
fn recover_a(
|
||||
_db: &dyn CycleDatabase,
|
||||
cycle: &salsa::Cycle,
|
||||
_: CycleDatabaseData,
|
||||
_abc: ABC,
|
||||
) -> Result<(), Error> {
|
||||
Err(Error { cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect() })
|
||||
}
|
||||
|
||||
fn cycle_b(db: &dyn CycleDatabase, abc: ABC) -> Result<(), Error> {
|
||||
abc.b(db).invoke(db, abc)
|
||||
}
|
||||
|
||||
fn recover_b(
|
||||
_db: &dyn CycleDatabase,
|
||||
cycle: &salsa::Cycle,
|
||||
_: CycleDatabaseData,
|
||||
_abc: ABC,
|
||||
) -> Result<(), Error> {
|
||||
Err(Error { cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect() })
|
||||
}
|
||||
|
||||
fn cycle_c(db: &dyn CycleDatabase, abc: ABC) -> Result<(), Error> {
|
||||
abc.c(db).invoke(db, abc)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_memoized() {
|
||||
let db = salsa::DatabaseImpl::new();
|
||||
|
||||
let input = MyInput::new(&db);
|
||||
let cycle = extract_cycle(|| memoized_a(&db, input));
|
||||
let expected = expect![[r#"
|
||||
[
|
||||
DatabaseKeyIndex(
|
||||
IngredientIndex(
|
||||
1,
|
||||
),
|
||||
Id(0),
|
||||
),
|
||||
DatabaseKeyIndex(
|
||||
IngredientIndex(
|
||||
2,
|
||||
),
|
||||
Id(0),
|
||||
),
|
||||
]
|
||||
"#]];
|
||||
expected.assert_debug_eq(&cycle.all_participants(&db));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inner_cycle() {
|
||||
// A --> B <-- C
|
||||
// ^ |
|
||||
// +-----+
|
||||
let db = salsa::DatabaseImpl::new();
|
||||
|
||||
let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::B);
|
||||
let err = db.cycle_c(abc);
|
||||
assert!(err.is_err());
|
||||
let expected = expect![[r#"
|
||||
[
|
||||
"cycle_a_shim(Id(1400))",
|
||||
"cycle_b_shim(Id(1000))",
|
||||
]
|
||||
"#]];
|
||||
expected.assert_debug_eq(&err.unwrap_err().cycle);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_revalidate() {
|
||||
// A --> B
|
||||
// ^ |
|
||||
// +-----+
|
||||
let mut db = salsa::DatabaseImpl::new();
|
||||
let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None);
|
||||
assert!(db.cycle_a(abc).is_err());
|
||||
abc.set_b(&mut db).to(CycleQuery::A); // same value as default
|
||||
assert!(db.cycle_a(abc).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_recovery_unchanged_twice() {
|
||||
// A --> B
|
||||
// ^ |
|
||||
// +-----+
|
||||
let mut db = salsa::DatabaseImpl::new();
|
||||
let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None);
|
||||
assert!(db.cycle_a(abc).is_err());
|
||||
|
||||
abc.set_c(&mut db).to(CycleQuery::A); // force new revision
|
||||
assert!(db.cycle_a(abc).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_appears() {
|
||||
let mut db = salsa::DatabaseImpl::new();
|
||||
// A --> B
|
||||
let abc = ABC::new(&db, CycleQuery::B, CycleQuery::None, CycleQuery::None);
|
||||
assert!(db.cycle_a(abc).is_ok());
|
||||
|
||||
// A --> B
|
||||
// ^ |
|
||||
// +-----+
|
||||
abc.set_b(&mut db).to(CycleQuery::A);
|
||||
assert!(db.cycle_a(abc).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_disappears() {
|
||||
let mut db = salsa::DatabaseImpl::new();
|
||||
|
||||
// A --> B
|
||||
// ^ |
|
||||
// +-----+
|
||||
let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None);
|
||||
assert!(db.cycle_a(abc).is_err());
|
||||
|
||||
// A --> B
|
||||
abc.set_b(&mut db).to(CycleQuery::None);
|
||||
assert!(db.cycle_a(abc).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_multiple() {
|
||||
// No matter whether we start from A or B, we get the same set of participants:
|
||||
let db = salsa::DatabaseImpl::new();
|
||||
|
||||
// Configuration:
|
||||
//
|
||||
// A --> B <-- C
|
||||
// ^ | ^
|
||||
// +-----+ |
|
||||
// | |
|
||||
// +-----+
|
||||
//
|
||||
// Here, conceptually, B encounters a cycle with A and then
|
||||
// recovers.
|
||||
let abc = ABC::new(&db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A);
|
||||
|
||||
let c = db.cycle_c(abc);
|
||||
let b = db.cycle_b(abc);
|
||||
let a = db.cycle_a(abc);
|
||||
let expected = expect![[r#"
|
||||
(
|
||||
[
|
||||
"cycle_a_shim(Id(1000))",
|
||||
"cycle_b_shim(Id(1400))",
|
||||
],
|
||||
[
|
||||
"cycle_a_shim(Id(1000))",
|
||||
"cycle_b_shim(Id(1400))",
|
||||
],
|
||||
[
|
||||
"cycle_a_shim(Id(1000))",
|
||||
"cycle_b_shim(Id(1400))",
|
||||
],
|
||||
)
|
||||
"#]];
|
||||
expected.assert_debug_eq(&(c.unwrap_err().cycle, b.unwrap_err().cycle, a.unwrap_err().cycle));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_mixed_1() {
|
||||
let db = salsa::DatabaseImpl::new();
|
||||
// A --> B <-- C
|
||||
// | ^
|
||||
// +-----+
|
||||
let abc = ABC::new(&db, CycleQuery::B, CycleQuery::C, CycleQuery::B);
|
||||
|
||||
let expected = expect![[r#"
|
||||
[
|
||||
"cycle_b_shim(Id(1000))",
|
||||
"cycle_c_shim(Id(c00))",
|
||||
]
|
||||
"#]];
|
||||
expected.assert_debug_eq(&db.cycle_c(abc).unwrap_err().cycle);
|
||||
}
|
125
crates/query-group-macro/tests/hello_world.rs
Normal file
125
crates/query-group-macro/tests/hello_world.rs
Normal file
|
@ -0,0 +1,125 @@
|
|||
use expect_test::expect;
|
||||
use query_group_macro::query_group;
|
||||
|
||||
mod logger_db;
|
||||
use logger_db::LoggerDb;
|
||||
|
||||
#[query_group]
|
||||
pub trait HelloWorldDatabase: salsa::Database {
|
||||
// input
|
||||
// // input with no params
|
||||
#[salsa::input]
|
||||
fn input_string(&self) -> String;
|
||||
|
||||
// unadorned query
|
||||
fn length_query(&self, key: ()) -> usize;
|
||||
|
||||
// unadorned query
|
||||
fn length_query_with_no_params(&self) -> usize;
|
||||
|
||||
// renamed/invoke query
|
||||
#[salsa::invoke(invoke_length_query_actual)]
|
||||
fn invoke_length_query(&self, key: ()) -> usize;
|
||||
|
||||
// not a query. should not invoked
|
||||
#[salsa::transparent]
|
||||
fn transparent_length(&self, key: ()) -> usize;
|
||||
|
||||
#[salsa::transparent]
|
||||
#[salsa::invoke(transparent_and_invoke_length_actual)]
|
||||
fn transparent_and_invoke_length(&self, key: ()) -> usize;
|
||||
}
|
||||
|
||||
fn length_query(db: &dyn HelloWorldDatabase, _key: ()) -> usize {
|
||||
db.input_string().len()
|
||||
}
|
||||
|
||||
fn length_query_with_no_params(db: &dyn HelloWorldDatabase) -> usize {
|
||||
db.input_string().len()
|
||||
}
|
||||
|
||||
fn invoke_length_query_actual(db: &dyn HelloWorldDatabase, _key: ()) -> usize {
|
||||
db.input_string().len()
|
||||
}
|
||||
|
||||
fn transparent_length(db: &dyn HelloWorldDatabase, _key: ()) -> usize {
|
||||
db.input_string().len()
|
||||
}
|
||||
|
||||
fn transparent_and_invoke_length_actual(db: &dyn HelloWorldDatabase, _key: ()) -> usize {
|
||||
db.input_string().len()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unadorned_query() {
|
||||
let mut db = LoggerDb::default();
|
||||
|
||||
db.set_input_string(String::from("Hello, world!"));
|
||||
let len = db.length_query(());
|
||||
|
||||
assert_eq!(len, 13);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_HelloWorldDatabase(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(DidValidateMemoizedValue { database_key: create_data_HelloWorldDatabase(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: length_query_shim(Id(800)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
]"#]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_query() {
|
||||
let mut db = LoggerDb::default();
|
||||
|
||||
db.set_input_string(String::from("Hello, world!"));
|
||||
let len = db.invoke_length_query(());
|
||||
|
||||
assert_eq!(len, 13);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_HelloWorldDatabase(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(DidValidateMemoizedValue { database_key: create_data_HelloWorldDatabase(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: invoke_length_query_shim(Id(800)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
]"#]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transparent() {
|
||||
let mut db = LoggerDb::default();
|
||||
|
||||
db.set_input_string(String::from("Hello, world!"));
|
||||
let len = db.transparent_length(());
|
||||
|
||||
assert_eq!(len, 13);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_HelloWorldDatabase(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(DidValidateMemoizedValue { database_key: create_data_HelloWorldDatabase(Id(0)) })",
|
||||
]"#]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transparent_invoke() {
|
||||
let mut db = LoggerDb::default();
|
||||
|
||||
db.set_input_string(String::from("Hello, world!"));
|
||||
let len = db.transparent_and_invoke_length(());
|
||||
|
||||
assert_eq!(len, 13);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_HelloWorldDatabase(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(DidValidateMemoizedValue { database_key: create_data_HelloWorldDatabase(Id(0)) })",
|
||||
]"#]]);
|
||||
}
|
52
crates/query-group-macro/tests/interned.rs
Normal file
52
crates/query-group-macro/tests/interned.rs
Normal file
|
@ -0,0 +1,52 @@
|
|||
use query_group_macro::query_group;
|
||||
|
||||
use expect_test::expect;
|
||||
use salsa::plumbing::AsId;
|
||||
|
||||
mod logger_db;
|
||||
use logger_db::LoggerDb;
|
||||
|
||||
#[salsa::interned(no_lifetime)]
|
||||
pub struct InternedString {
|
||||
data: String,
|
||||
}
|
||||
|
||||
#[query_group]
|
||||
pub trait InternedDB: salsa::Database {
|
||||
#[salsa::interned]
|
||||
fn intern_string(&self, data: String) -> InternedString;
|
||||
|
||||
fn interned_len(&self, id: InternedString) -> usize;
|
||||
}
|
||||
|
||||
fn interned_len(db: &dyn InternedDB, id: InternedString) -> usize {
|
||||
db.lookup_intern_string(id).len()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn intern_round_trip() {
|
||||
let db = LoggerDb::default();
|
||||
|
||||
let id = db.intern_string(String::from("Hello, world!"));
|
||||
let s = db.lookup_intern_string(id);
|
||||
|
||||
assert_eq!(s.len(), 13);
|
||||
db.assert_logs(expect![[r#"[]"#]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn intern_with_query() {
|
||||
let db = LoggerDb::default();
|
||||
|
||||
let id = db.intern_string(String::from("Hello, world!"));
|
||||
let len = db.interned_len(id);
|
||||
|
||||
assert_eq!(len, 13);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_InternedDB(Id(400)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: interned_len_shim(Id(c00)) })",
|
||||
]"#]]);
|
||||
}
|
60
crates/query-group-macro/tests/logger_db.rs
Normal file
60
crates/query-group-macro/tests/logger_db.rs
Normal file
|
@ -0,0 +1,60 @@
|
|||
use std::sync::{Arc, Mutex};
|
||||
|
||||
#[salsa::db]
|
||||
#[derive(Default, Clone)]
|
||||
pub(crate) struct LoggerDb {
|
||||
storage: salsa::Storage<Self>,
|
||||
logger: Logger,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
struct Logger {
|
||||
logs: Arc<Mutex<Vec<String>>>,
|
||||
}
|
||||
|
||||
#[salsa::db]
|
||||
impl salsa::Database for LoggerDb {
|
||||
fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) {
|
||||
let event = event();
|
||||
match event.kind {
|
||||
salsa::EventKind::WillExecute { .. }
|
||||
| salsa::EventKind::WillCheckCancellation { .. }
|
||||
| salsa::EventKind::DidValidateMemoizedValue { .. }
|
||||
| salsa::EventKind::WillDiscardStaleOutput { .. }
|
||||
| salsa::EventKind::DidDiscard { .. } => {
|
||||
self.push_log(format!("salsa_event({:?})", event.kind));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LoggerDb {
|
||||
/// Log an event from inside a tracked function.
|
||||
pub(crate) fn push_log(&self, string: String) {
|
||||
self.logger.logs.lock().unwrap().push(string);
|
||||
}
|
||||
|
||||
/// Asserts what the (formatted) logs should look like,
|
||||
/// clearing the logged events. This takes `&mut self` because
|
||||
/// it is meant to be run from outside any tracked functions.
|
||||
pub(crate) fn assert_logs(&self, expected: expect_test::Expect) {
|
||||
let logs = std::mem::take(&mut *self.logger.logs.lock().unwrap());
|
||||
expected.assert_eq(&format!("{:#?}", logs));
|
||||
}
|
||||
}
|
||||
|
||||
/// Test the logger database.
|
||||
///
|
||||
/// This test isn't very interesting, but it *does* remove a dead code warning.
|
||||
#[test]
|
||||
fn test_logger_db() {
|
||||
let db = LoggerDb::default();
|
||||
db.push_log("test".to_string());
|
||||
db.assert_logs(expect_test::expect![
|
||||
r#"
|
||||
[
|
||||
"test",
|
||||
]"#
|
||||
]);
|
||||
}
|
67
crates/query-group-macro/tests/lru.rs
Normal file
67
crates/query-group-macro/tests/lru.rs
Normal file
|
@ -0,0 +1,67 @@
|
|||
use expect_test::expect;
|
||||
|
||||
mod logger_db;
|
||||
use logger_db::LoggerDb;
|
||||
use query_group_macro::query_group;
|
||||
|
||||
#[query_group]
|
||||
pub trait LruDB: salsa::Database {
|
||||
// // input with no params
|
||||
#[salsa::input]
|
||||
fn input_string(&self) -> String;
|
||||
|
||||
#[salsa::lru(16)]
|
||||
fn length_query(&self, key: ()) -> usize;
|
||||
|
||||
#[salsa::lru(16)]
|
||||
#[salsa::invoke(invoked_query)]
|
||||
fn length_query_invoke(&self, key: ()) -> usize;
|
||||
}
|
||||
|
||||
fn length_query(db: &dyn LruDB, _key: ()) -> usize {
|
||||
db.input_string().len()
|
||||
}
|
||||
|
||||
fn invoked_query(db: &dyn LruDB, _key: ()) -> usize {
|
||||
db.input_string().len()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn plain_lru() {
|
||||
let mut db = LoggerDb::default();
|
||||
|
||||
db.set_input_string(String::from("Hello, world!"));
|
||||
let len = db.length_query(());
|
||||
|
||||
assert_eq!(len, 13);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_LruDB(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(DidValidateMemoizedValue { database_key: create_data_LruDB(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: length_query_shim(Id(800)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
]"#]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_lru() {
|
||||
let mut db = LoggerDb::default();
|
||||
|
||||
db.set_input_string(String::from("Hello, world!"));
|
||||
let len = db.length_query_invoke(());
|
||||
|
||||
assert_eq!(len, 13);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_LruDB(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(DidValidateMemoizedValue { database_key: create_data_LruDB(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: length_query_invoke_shim(Id(800)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
]"#]]);
|
||||
}
|
23
crates/query-group-macro/tests/multiple_dbs.rs
Normal file
23
crates/query-group-macro/tests/multiple_dbs.rs
Normal file
|
@ -0,0 +1,23 @@
|
|||
use query_group_macro::query_group;
|
||||
|
||||
#[query_group]
|
||||
pub trait DatabaseOne: salsa::Database {
|
||||
#[salsa::input]
|
||||
fn input_string(&self) -> String;
|
||||
|
||||
// unadorned query
|
||||
fn length(&self, key: ()) -> usize;
|
||||
}
|
||||
|
||||
#[query_group]
|
||||
pub trait DatabaseTwo: DatabaseOne {
|
||||
fn second_length(&self, key: ()) -> usize;
|
||||
}
|
||||
|
||||
fn length(db: &dyn DatabaseOne, _key: ()) -> usize {
|
||||
db.input_string().len()
|
||||
}
|
||||
|
||||
fn second_length(db: &dyn DatabaseTwo, _key: ()) -> usize {
|
||||
db.input_string().len()
|
||||
}
|
115
crates/query-group-macro/tests/old_and_new.rs
Normal file
115
crates/query-group-macro/tests/old_and_new.rs
Normal file
|
@ -0,0 +1,115 @@
|
|||
use expect_test::expect;
|
||||
|
||||
mod logger_db;
|
||||
use logger_db::LoggerDb;
|
||||
use query_group_macro::query_group;
|
||||
|
||||
#[salsa::input]
|
||||
struct Input {
|
||||
str: String,
|
||||
}
|
||||
|
||||
#[query_group]
|
||||
trait PartialMigrationDatabase: salsa::Database {
|
||||
fn length_query(&self, input: Input) -> usize;
|
||||
|
||||
// renamed/invoke query
|
||||
#[salsa::invoke(invoke_length_query_actual)]
|
||||
fn invoke_length_query(&self, input: Input) -> usize;
|
||||
|
||||
// invoke tracked function
|
||||
#[salsa::invoke(invoke_length_tracked_actual)]
|
||||
fn invoke_length_tracked(&self, input: Input) -> usize;
|
||||
}
|
||||
|
||||
fn length_query(db: &dyn PartialMigrationDatabase, input: Input) -> usize {
|
||||
input.str(db).len()
|
||||
}
|
||||
|
||||
fn invoke_length_query_actual(db: &dyn PartialMigrationDatabase, input: Input) -> usize {
|
||||
input.str(db).len()
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn invoke_length_tracked_actual(db: &dyn PartialMigrationDatabase, input: Input) -> usize {
|
||||
input.str(db).len()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unadorned_query() {
|
||||
let db = LoggerDb::default();
|
||||
|
||||
let input = Input::new(&db, String::from("Hello, world!"));
|
||||
let len = db.length_query(input);
|
||||
|
||||
assert_eq!(len, 13);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_PartialMigrationDatabase(Id(400)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: length_query_shim(Id(c00)) })",
|
||||
]"#]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_query() {
|
||||
let db = LoggerDb::default();
|
||||
|
||||
let input = Input::new(&db, String::from("Hello, world!"));
|
||||
let len = db.invoke_length_query(input);
|
||||
|
||||
assert_eq!(len, 13);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_PartialMigrationDatabase(Id(400)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: invoke_length_query_shim(Id(c00)) })",
|
||||
]"#]]);
|
||||
}
|
||||
|
||||
// todo: does this even make sense?
|
||||
#[test]
|
||||
fn invoke_tracked_query() {
|
||||
let db = LoggerDb::default();
|
||||
|
||||
let input = Input::new(&db, String::from("Hello, world!"));
|
||||
let len = db.invoke_length_tracked(input);
|
||||
|
||||
assert_eq!(len, 13);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_PartialMigrationDatabase(Id(400)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: invoke_length_tracked_shim(Id(c00)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: invoke_length_tracked_actual(Id(0)) })",
|
||||
]"#]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_salsa_baseline() {
|
||||
let db = LoggerDb::default();
|
||||
|
||||
#[salsa::input]
|
||||
struct Input {
|
||||
str: String,
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn new_salsa_length_query(db: &dyn PartialMigrationDatabase, input: Input) -> usize {
|
||||
input.str(db).len()
|
||||
}
|
||||
|
||||
let input = Input::new(&db, String::from("Hello, world!"));
|
||||
let len = new_salsa_length_query(&db, input);
|
||||
|
||||
assert_eq!(len, 13);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: new_salsa_length_query(Id(0)) })",
|
||||
]"#]]);
|
||||
}
|
50
crates/query-group-macro/tests/result.rs
Normal file
50
crates/query-group-macro/tests/result.rs
Normal file
|
@ -0,0 +1,50 @@
|
|||
mod logger_db;
|
||||
use expect_test::expect;
|
||||
use logger_db::LoggerDb;
|
||||
|
||||
use query_group_macro::query_group;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct Error;
|
||||
|
||||
#[query_group]
|
||||
pub trait ResultDatabase: salsa::Database {
|
||||
#[salsa::input]
|
||||
fn input_string(&self) -> String;
|
||||
|
||||
fn length(&self, key: ()) -> Result<usize, Error>;
|
||||
|
||||
fn length2(&self, key: ()) -> Result<usize, Error>;
|
||||
}
|
||||
|
||||
fn length(db: &dyn ResultDatabase, _key: ()) -> Result<usize, Error> {
|
||||
Ok(db.input_string().len())
|
||||
}
|
||||
|
||||
fn length2(db: &dyn ResultDatabase, _key: ()) -> Result<usize, Error> {
|
||||
Ok(db.input_string().len())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_queries_with_results() {
|
||||
let mut db = LoggerDb::default();
|
||||
let input = "hello";
|
||||
db.set_input_string(input.to_owned());
|
||||
assert_eq!(db.length(()), Ok(input.len()));
|
||||
assert_eq!(db.length2(()), Ok(input.len()));
|
||||
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_ResultDatabase(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(DidValidateMemoizedValue { database_key: create_data_ResultDatabase(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: length_shim(Id(800)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: length2_shim(Id(c00)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
]"#]]);
|
||||
}
|
19
crates/query-group-macro/tests/supertrait.rs
Normal file
19
crates/query-group-macro/tests/supertrait.rs
Normal file
|
@ -0,0 +1,19 @@
|
|||
use query_group_macro::query_group;
|
||||
|
||||
#[salsa::db]
|
||||
pub trait SourceDb: salsa::Database {
|
||||
/// Text of the file.
|
||||
fn file_text(&self, id: usize) -> String;
|
||||
}
|
||||
|
||||
#[query_group]
|
||||
pub trait RootDb: SourceDb {
|
||||
fn parse(&self, id: usize) -> String;
|
||||
}
|
||||
|
||||
fn parse(db: &dyn RootDb, id: usize) -> String {
|
||||
// this is the test: does the following compile?
|
||||
db.file_text(id);
|
||||
|
||||
String::new()
|
||||
}
|
38
crates/query-group-macro/tests/tuples.rs
Normal file
38
crates/query-group-macro/tests/tuples.rs
Normal file
|
@ -0,0 +1,38 @@
|
|||
use query_group_macro::query_group;
|
||||
|
||||
mod logger_db;
|
||||
use expect_test::expect;
|
||||
use logger_db::LoggerDb;
|
||||
|
||||
#[query_group]
|
||||
pub trait HelloWorldDatabase: salsa::Database {
|
||||
#[salsa::input]
|
||||
fn input_string(&self) -> String;
|
||||
|
||||
fn length_query(&self, key: ()) -> (usize, usize);
|
||||
}
|
||||
|
||||
fn length_query(db: &dyn HelloWorldDatabase, _key: ()) -> (usize, usize) {
|
||||
let len = db.input_string().len();
|
||||
(len, len)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query() {
|
||||
let mut db = LoggerDb::default();
|
||||
|
||||
db.set_input_string(String::from("Hello, world!"));
|
||||
let len = db.length_query(());
|
||||
|
||||
assert_eq!(len, (13, 13));
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: create_data_HelloWorldDatabase(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(DidValidateMemoizedValue { database_key: create_data_HelloWorldDatabase(Id(0)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
"salsa_event(WillExecute { database_key: length_query_shim(Id(800)) })",
|
||||
"salsa_event(WillCheckCancellation)",
|
||||
]"#]]);
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue