Changed return_ref syntax to returns(as_ref) and returns(cloned) (#772)

* Changed `return_ref` syntax to `returns(as_ref)` and `returns(cloned)`

* Implement

* renamed module for return_mode

* Rename macro, fix docs, add tests, validate return modes

* Cargo fmt

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
CheaterCodes 2025-05-09 09:28:54 +02:00 committed by GitHub
parent d1da99132d
commit 13a2bd7461
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
54 changed files with 538 additions and 172 deletions

View file

@ -14,8 +14,8 @@
mod macro_if;
mod maybe_backdate;
mod maybe_clone;
mod maybe_default;
mod return_mode;
mod setup_accumulator_impl;
mod setup_input_struct;
mod setup_interned_struct;

View file

@ -2,7 +2,7 @@
#[macro_export]
macro_rules! maybe_backdate {
(
($maybe_clone:ident, no_backdate, $maybe_default:ident),
($return_mode:ident, no_backdate, $maybe_default:ident),
$field_ty:ty,
$old_field_place:expr,
$new_field_place:expr,
@ -20,7 +20,7 @@ macro_rules! maybe_backdate {
};
(
($maybe_clone:ident, backdate, $maybe_default:ident),
($return_mode:ident, backdate, $maybe_default:ident),
$field_ty:ty,
$old_field_place:expr,
$new_field_place:expr,

View file

@ -1,40 +0,0 @@
/// Generate either `field_ref_expr` or a clone of that expr.
///
/// Used when generating field getters.
#[macro_export]
macro_rules! maybe_clone {
(
(no_clone, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
$field_ref_expr
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
std::clone::Clone::clone($field_ref_expr)
};
}
#[macro_export]
macro_rules! maybe_cloned_ty {
(
(no_clone, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
& $db_lt $field_ty
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
$field_ty
};
}

View file

@ -4,7 +4,7 @@
#[macro_export]
macro_rules! maybe_default {
(
($maybe_clone:ident, $maybe_backdate:ident, default),
($return_mode:ident, $maybe_backdate:ident, default),
$field_ty:ty,
$field_ref_expr:expr,
) => {
@ -12,7 +12,7 @@ macro_rules! maybe_default {
};
(
($maybe_clone:ident, $maybe_backdate:ident, required),
($return_mode:ident, $maybe_backdate:ident, required),
$field_ty:ty,
$field_ref_expr:expr,
) => {
@ -22,11 +22,11 @@ macro_rules! maybe_default {
#[macro_export]
macro_rules! maybe_default_tt {
(($maybe_clone:ident, $maybe_backdate:ident, default) => $($t:tt)*) => {
(($return_mode:ident, $maybe_backdate:ident, default) => $($t:tt)*) => {
$($t)*
};
(($maybe_clone:ident, $maybe_backdate:ident, required) => $($t:tt)*) => {
(($return_mode:ident, $maybe_backdate:ident, required) => $($t:tt)*) => {
};
}

View file

@ -0,0 +1,104 @@
/// Generate the expression for the return type, depending on the return mode defined in [`salsa-macros::options::Options::returns`]
///
/// Used when generating field getters.
#[macro_export]
macro_rules! return_mode_expression {
(
(copy, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
*$field_ref_expr
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::core::clone::Clone::clone($field_ref_expr)
};
(
(ref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
$field_ref_expr
};
(
(deref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::core::ops::Deref::deref($field_ref_expr)
};
(
(as_ref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::salsa::SalsaAsRef::as_ref($field_ref_expr)
};
(
(as_deref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::salsa::SalsaAsDeref::as_deref($field_ref_expr)
};
}
#[macro_export]
macro_rules! return_mode_ty {
(
(copy, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
$field_ty
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
$field_ty
};
(
(ref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
& $db_lt $field_ty
};
(
(deref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
& $db_lt <$field_ty as ::core::ops::Deref>::Target
};
(
(as_ref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
<$field_ty as ::salsa::SalsaAsRef>::AsRef<$db_lt>
};
(
(as_deref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
<$field_ty as ::salsa::SalsaAsDeref>::AsDeref<$db_lt>
};
}

View file

@ -182,7 +182,7 @@ macro_rules! setup_input_struct {
}
$(
$field_getter_vis fn $field_getter_id<'db, $Db>(self, db: &'db $Db) -> $zalsa::maybe_cloned_ty!($field_option, 'db, $field_ty)
$field_getter_vis fn $field_getter_id<'db, $Db>(self, db: &'db $Db) -> $zalsa::return_mode_ty!($field_option, 'db, $field_ty)
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database,
@ -192,7 +192,7 @@ macro_rules! setup_input_struct {
self,
$field_index,
);
$zalsa::maybe_clone!(
$zalsa::return_mode_expression!(
$field_option,
$field_ty,
&fields.$field_index,

View file

@ -215,13 +215,13 @@ macro_rules! setup_interned_struct {
}
$(
$field_getter_vis fn $field_getter_id<$Db>(self, db: &'db $Db) -> $zalsa::maybe_cloned_ty!($field_option, 'db, $field_ty)
$field_getter_vis fn $field_getter_id<$Db>(self, db: &'db $Db) -> $zalsa::return_mode_ty!($field_option, 'db, $field_ty)
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database,
{
let fields = $Configuration::ingredient(db).fields(db.as_dyn_database(), self);
$zalsa::maybe_clone!(
$zalsa::return_mode_expression!(
$field_option,
$field_ty,
&fields.$field_index,

View file

@ -55,8 +55,8 @@ macro_rules! setup_tracked_fn {
// LRU capacity (a literal, maybe 0)
lru: $lru:tt,
// True if we `return_ref` flag was given to the function
return_ref: $return_ref:tt,
// The return mode for the function, see `salsa_macros::options::Option::returns`
return_mode: $return_mode:tt,
assert_return_type_is_update: {$($assert_return_type_is_update:tt)*},
@ -80,13 +80,7 @@ macro_rules! setup_tracked_fn {
$vis fn $fn_name<$db_lt>(
$db: &$db_lt dyn $Db,
$($input_id: $input_ty,)*
) -> salsa::plumbing::macro_if! {
if $return_ref {
&$db_lt $output_ty
} else {
$output_ty
}
} {
) -> salsa::plumbing::return_mode_ty!(($return_mode, __, __), $db_lt, $output_ty) {
use salsa::plumbing as $zalsa;
struct $Configuration;
@ -372,13 +366,7 @@ macro_rules! setup_tracked_fn {
}
};
$zalsa::macro_if! {
if $return_ref {
result
} else {
<$output_ty as std::clone::Clone>::clone(result)
}
}
$zalsa::return_mode_expression!(($return_mode, __, __), $output_ty, result,)
})
}
// The struct needs be last in the macro expansion in order to make the tracked

View file

@ -52,24 +52,24 @@ macro_rules! setup_tracked_struct {
// A set of "field options" for each tracked field.
//
// Each field option is a tuple `(maybe_clone, maybe_backdate)` where:
// Each field option is a tuple `(return_mode, maybe_backdate)` where:
//
// * `maybe_clone` is either the identifier `clone` or `no_clone`
// * `return_mode` is an identifier as specified in `salsa_macros::options::Option::returns`
// * `maybe_backdate` is either the identifier `backdate` or `no_backdate`
//
// These are used to drive conditional logic for each field via recursive macro invocation
// (see e.g. @maybe_clone below).
// (see e.g. @return_mode below).
tracked_options: [$($tracked_option:tt),*],
// A set of "field options" for each untracked field.
//
// Each field option is a tuple `(maybe_clone, maybe_backdate)` where:
// Each field option is a tuple `(return_mode, maybe_backdate)` where:
//
// * `maybe_clone` is either the identifier `clone` or `no_clone`
// * `return_mode` is an identifier as specified in `salsa_macros::options::Option::returns`
// * `maybe_backdate` is either the identifier `backdate` or `no_backdate`
//
// These are used to drive conditional logic for each field via recursive macro invocation
// (see e.g. @maybe_clone below).
// (see e.g. @return_mode below).
untracked_options: [$($untracked_option:tt),*],
// Number of tracked fields.
@ -260,14 +260,14 @@ macro_rules! setup_tracked_struct {
}
$(
$tracked_getter_vis fn $tracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::maybe_cloned_ty!($tracked_option, $db_lt, $tracked_ty)
$tracked_getter_vis fn $tracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::return_mode_ty!($tracked_option, $db_lt, $tracked_ty)
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database,
{
let db = db.as_dyn_database();
let fields = $Configuration::ingredient(db).tracked_field(db, self, $relative_tracked_index);
$crate::maybe_clone!(
$crate::return_mode_expression!(
$tracked_option,
$tracked_ty,
&fields.$absolute_tracked_index,
@ -276,14 +276,14 @@ macro_rules! setup_tracked_struct {
)*
$(
$untracked_getter_vis fn $untracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::maybe_cloned_ty!($untracked_option, $db_lt, $untracked_ty)
$untracked_getter_vis fn $untracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::return_mode_ty!($untracked_option, $db_lt, $untracked_ty)
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database,
{
let db = db.as_dyn_database();
let fields = $Configuration::ingredient(db).untracked_field(db, self);
$crate::maybe_clone!(
$crate::return_mode_expression!(
$untracked_option,
$untracked_ty,
&fields.$absolute_untracked_index,

View file

@ -29,7 +29,7 @@ pub(crate) fn accumulator(
struct Accumulator;
impl AllowedOptions for Accumulator {
const RETURN_REF: bool = false;
const RETURNS: bool = false;
const SPECIFY: bool = false;
const NO_EQ: bool = false;
const DEBUG: bool = false;

View file

@ -33,7 +33,7 @@ type InputArgs = Options<InputStruct>;
struct InputStruct;
impl crate::options::AllowedOptions for InputStruct {
const RETURN_REF: bool = false;
const RETURNS: bool = false;
const SPECIFY: bool = false;

View file

@ -33,7 +33,7 @@ type InternedArgs = Options<InternedStruct>;
struct InternedStruct;
impl crate::options::AllowedOptions for InternedStruct {
const RETURN_REF: bool = false;
const RETURNS: bool = false;
const SPECIFY: bool = false;

View file

@ -1,6 +1,7 @@
use std::marker::PhantomData;
use syn::ext::IdentExt;
use syn::parenthesized;
use syn::spanned::Spanned;
/// "Options" are flags that can be supplied to the various salsa related
@ -10,10 +11,11 @@ use syn::spanned::Spanned;
/// trait.
#[derive(Debug)]
pub(crate) struct Options<A: AllowedOptions> {
/// The `return_ref` option is used to signal that field/return type is "by ref"
/// The `returns` option is used to configure the "return mode" for the field/function.
/// This may be one of `copy`, `clone`, `ref`, `as_ref`, `as_deref`.
///
/// If this is `Some`, the value is the `ref` identifier.
pub return_ref: Option<syn::Ident>,
/// If this is `Some`, the value is the ident representing the selected mode.
pub returns: Option<syn::Ident>,
/// The `no_eq` option is used to signal that a given field does not implement
/// the `Eq` trait and cannot be compared for equality.
@ -96,7 +98,7 @@ pub(crate) struct Options<A: AllowedOptions> {
impl<A: AllowedOptions> Default for Options<A> {
fn default() -> Self {
Self {
return_ref: Default::default(),
returns: Default::default(),
specify: Default::default(),
no_eq: Default::default(),
debug: Default::default(),
@ -118,7 +120,7 @@ impl<A: AllowedOptions> Default for Options<A> {
/// These flags determine which options are allowed in a given context
pub(crate) trait AllowedOptions {
const RETURN_REF: bool;
const RETURNS: bool;
const SPECIFY: bool;
const NO_EQ: bool;
const DEBUG: bool;
@ -144,18 +146,21 @@ impl<A: AllowedOptions> syn::parse::Parse for Options<A> {
while !input.is_empty() {
let ident: syn::Ident = syn::Ident::parse_any(input)?;
if ident == "return_ref" {
if A::RETURN_REF {
if let Some(old) = options.return_ref.replace(ident) {
if ident == "returns" {
let content;
parenthesized!(content in input);
let mode = syn::Ident::parse_any(&content)?;
if A::RETURNS {
if let Some(old) = options.returns.replace(mode) {
return Err(syn::Error::new(
old.span(),
"option `return_ref` provided twice",
"option `returns` provided twice",
));
}
} else {
return Err(syn::Error::new(
ident.span(),
"`return_ref` option not allowed here",
"`returns` option not allowed here",
));
}
} else if ident == "no_eq" {

View file

@ -26,6 +26,7 @@
//! * this could be optimized, particularly for interned fields
use proc_macro2::{Ident, Literal, Span, TokenStream};
use syn::{ext::IdentExt, spanned::Spanned};
use crate::db_lifetime;
use crate::options::{AllowedOptions, Options};
@ -58,19 +59,22 @@ pub(crate) struct SalsaField<'s> {
pub(crate) has_tracked_attr: bool,
pub(crate) has_default_attr: bool,
pub(crate) has_ref_attr: bool,
pub(crate) returns: syn::Ident,
pub(crate) has_no_eq_attr: bool,
get_name: syn::Ident,
set_name: syn::Ident,
}
const BANNED_FIELD_NAMES: &[&str] = &["from", "new"];
const ALLOWED_RETURN_MODES: &[&str] = &["copy", "clone", "ref", "deref", "as_ref", "as_deref"];
#[allow(clippy::type_complexity)]
pub(crate) const FIELD_OPTION_ATTRIBUTES: &[(&str, fn(&syn::Attribute, &mut SalsaField))] = &[
("tracked", |_, ef| ef.has_tracked_attr = true),
("default", |_, ef| ef.has_default_attr = true),
("return_ref", |_, ef| ef.has_ref_attr = true),
("returns", |attr, ef| {
ef.returns = attr.parse_args_with(syn::Ident::parse_any).unwrap();
}),
("no_eq", |_, ef| ef.has_no_eq_attr = true),
("get", |attr, ef| {
ef.get_name = attr.parse_args().unwrap();
@ -364,10 +368,11 @@ impl<'s> SalsaField<'s> {
let get_name = Ident::new(&field_name_str, field_name.span());
let set_name = Ident::new(&format!("set_{field_name_str}",), field_name.span());
let returns = Ident::new("clone", field.span());
let mut result = SalsaField {
field,
has_tracked_attr: false,
has_ref_attr: false,
returns,
has_default_attr: false,
has_no_eq_attr: false,
get_name,
@ -383,15 +388,22 @@ impl<'s> SalsaField<'s> {
}
}
// Validate return mode
if !ALLOWED_RETURN_MODES
.iter()
.any(|mode| mode == &result.returns.to_string())
{
return Err(syn::Error::new(
result.returns.span(),
format!("Invalid return mode. Allowed modes are: {ALLOWED_RETURN_MODES:?}"),
));
}
Ok(result)
}
fn options(&self) -> TokenStream {
let clone_ident = if self.has_ref_attr {
syn::Ident::new("no_clone", Span::call_site())
} else {
syn::Ident::new("clone", Span::call_site())
};
let returns = &self.returns;
let backdate_ident = if self.has_no_eq_attr {
syn::Ident::new("no_backdate", Span::call_site())
@ -405,6 +417,6 @@ impl<'s> SalsaField<'s> {
syn::Ident::new("required", Span::call_site())
};
quote!((#clone_ident, #backdate_ident, #default_ident))
quote!((#returns, #backdate_ident, #default_ident))
}
}

View file

@ -1,7 +1,7 @@
use proc_macro2::{Literal, Span, TokenStream};
use quote::ToTokens;
use syn::spanned::Spanned;
use syn::ItemFn;
use syn::{Ident, ItemFn};
use crate::hygiene::Hygiene;
use crate::options::Options;
@ -26,7 +26,7 @@ pub type FnArgs = Options<TrackedFn>;
pub struct TrackedFn;
impl crate::options::AllowedOptions for TrackedFn {
const RETURN_REF: bool = true;
const RETURNS: bool = true;
const SPECIFY: bool = true;
@ -67,6 +67,8 @@ struct ValidFn<'item> {
db_path: &'item syn::Path,
}
const ALLOWED_RETURN_MODES: &[&str] = &["copy", "clone", "ref", "deref", "as_ref", "as_deref"];
#[allow(non_snake_case)]
impl Macro {
fn try_fn(&self, item: syn::ItemFn) -> syn::Result<TokenStream> {
@ -146,7 +148,22 @@ impl Macro {
let lru = Literal::usize_unsuffixed(self.args.lru.unwrap_or(0));
let return_ref: bool = self.args.return_ref.is_some();
let return_mode = self
.args
.returns
.clone()
.unwrap_or(Ident::new("clone", Span::call_site()));
// Validate return mode
if !ALLOWED_RETURN_MODES
.iter()
.any(|mode| mode == &return_mode.to_string())
{
return Err(syn::Error::new(
return_mode.span(),
format!("Invalid return mode. Allowed modes are: {ALLOWED_RETURN_MODES:?}"),
));
}
// The path expression is responsible for emitting the primary span in the diagnostic we
// want, so by uniformly using `output_ty.span()` we ensure that the diagnostic is emitted
@ -183,7 +200,7 @@ impl Macro {
values_equal: {#eq},
needs_interner: #needs_interner,
lru: #lru,
return_ref: #return_ref,
return_mode: #return_mode,
assert_return_type_is_update: { #assert_return_type_is_update },
unused_names: [
#zalsa,

View file

@ -296,13 +296,13 @@ impl Macro {
args: &FnArgs,
db_lt: &Option<syn::Lifetime>,
) -> syn::Result<()> {
if let Some(return_ref) = &args.return_ref {
if let Some(returns) = &args.returns {
if let syn::ReturnType::Type(_, t) = &mut sig.output {
**t = parse_quote!(& #db_lt #t)
} else {
return Err(syn::Error::new_spanned(
return_ref,
"return_ref attribute requires explicit return type",
returns,
"returns attribute requires explicit return type",
));
};
}

View file

@ -28,7 +28,7 @@ type TrackedArgs = Options<TrackedStruct>;
struct TrackedStruct;
impl crate::options::AllowedOptions for TrackedStruct {
const RETURN_REF: bool = false;
const RETURNS: bool = false;
const SPECIFY: bool = false;