Changed return_ref syntax to returns(as_ref) and returns(cloned) (#772)

* Changed `return_ref` syntax to `returns(as_ref)` and `returns(cloned)`

* Implement

* renamed module for return_mode

* Rename macro, fix docs, add tests, validate return modes

* Cargo fmt

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
CheaterCodes 2025-05-09 09:28:54 +02:00 committed by GitHub
parent d1da99132d
commit 13a2bd7461
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
54 changed files with 538 additions and 172 deletions

View file

@ -10,7 +10,7 @@ include!("shims/global_alloc_overwrite.rs");
#[salsa::input]
pub struct Input {
#[return_ref]
#[returns(ref)]
pub text: String,
}
@ -22,7 +22,7 @@ pub fn length(db: &dyn salsa::Database, input: Input) -> usize {
#[salsa::interned]
pub struct InternedInput<'db> {
#[return_ref]
#[returns(ref)]
pub text: String,
}

View file

@ -15,7 +15,7 @@ struct Tracked<'db> {
number: usize,
}
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
#[inline(never)]
fn index<'db>(db: &'db dyn salsa::Database, input: Input) -> Vec<Tracked<'db>> {
(0..input.field(db)).map(|i| Tracked::new(db, i)).collect()

View file

@ -79,7 +79,7 @@ pub struct ProgramFile(salsa::Id);
This means that, when you have a `ProgramFile`, you can easily copy it around and put it wherever you like.
To actually read any of its fields, however, you will need to use the database and a getter method.
### Reading fields and `return_ref`
### Reading fields and `returns(ref)`
You can access the value of an input's fields by using the getter method.
As this is only reading the field, it just needs a `&`-reference to the database:
@ -89,13 +89,13 @@ let contents: String = file.contents(&db);
```
Invoking the accessor clones the value from the database.
Sometimes this is not what you want, so you can annotate fields with `#[return_ref]` to indicate that they should return a reference into the database instead:
Sometimes this is not what you want, so you can annotate fields with `#[returns(ref)]` to indicate that they should return a reference into the database instead:
```rust
#[salsa::input]
pub struct ProgramFile {
pub path: PathBuf,
#[return_ref]
#[returns(ref)]
pub contents: String,
}
```
@ -145,7 +145,7 @@ Tracked functions have to follow a particular structure:
- They must take a "Salsa struct" as the second argument -- in our example, this is an input struct, but there are other kinds of Salsa structs we'll describe shortly.
- They _can_ take additional arguments, but it's faster and better if they don't.
Tracked functions can return any clone-able type. A clone is required since, when the value is cached, the result will be cloned out of the database. Tracked functions can also be annotated with `#[return_ref]` if you would prefer to return a reference into the database instead (if `parse_file` were so annotated, then callers would actually get back an `&Ast`, for example).
Tracked functions can return any clone-able type. A clone is required since, when the value is cached, the result will be cloned out of the database. Tracked functions can also be annotated with `#[returns(ref)]` if you would prefer to return a reference into the database instead (if `parse_file` were so annotated, then callers would actually get back an `&Ast`, for example).
## Tracked structs
@ -158,7 +158,7 @@ Example:
```rust
#[salsa::tracked]
struct Ast<'db> {
#[return_ref]
#[returns(ref)]
top_level_items: Vec<Item>,
}
```
@ -252,7 +252,7 @@ Most compilers, for example, will define a type to represent a user identifier:
```rust
#[salsa::interned]
struct Word {
#[return_ref]
#[returns(ref)]
pub text: String,
}
```
@ -269,7 +269,7 @@ let w3 = Word::new(db, "foo".to_string());
When you create two interned structs with the same field values, you are guaranteed to get back the same integer id. So here, we know that `assert_eq!(w1, w3)` is true and `assert_ne!(w1, w2)`.
You can access the fields of an interned struct using a getter, like `word.text(db)`. These getters respect the `#[return_ref]` annotation. Like tracked structs, the fields of interned structs are immutable.
You can access the fields of an interned struct using a getter, like `word.text(db)`. These getters respect the `#[returns(ref)]` annotation. Like tracked structs, the fields of interned structs are immutable.
## Accumulators

View file

@ -20,7 +20,7 @@ fn parse_module(db: &dyn Db, module: Module) -> Ast {
Ast::parse_text(module_text)
}
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
fn module_text(db: &dyn Db, module: Module) -> String {
panic!("text for module `{module:?}` not set")
}

View file

@ -92,7 +92,7 @@ Apart from the fields being immutable, the API for working with a tracked struct
- You can create a new value by using `new`: e.g., `Program::new(&db, some_statements)`
- You use a getter to read the value of a field, just like with an input (e.g., `my_func.statements(db)` to read the `statements` field).
- In this case, the field is tagged as `#[return_ref]`, which means that the getter will return a `&Vec<Statement>`, instead of cloning the vector.
- In this case, the field is tagged as `#[returns(ref)]`, which means that the getter will return a `&Vec<Statement>`, instead of cloning the vector.
### The `'db` lifetime

View file

@ -61,12 +61,12 @@ Tracked functions may take other arguments as well, though our examples here do
Functions that take additional arguments are less efficient and flexible.
It's generally better to structure tracked functions as functions of a single Salsa struct if possible.
### The `return_ref` annotation
### The `returns(ref)` annotation
You may have noticed that `parse_statements` is tagged with `#[salsa::tracked(return_ref)]`.
You may have noticed that `parse_statements` is tagged with `#[salsa::tracked(returns(ref))]`.
Ordinarily, when you call a tracked function, the result you get back is cloned out of the database.
The `return_ref` attribute means that a reference into the database is returned instead.
The `returns(ref)` attribute means that a reference into the database is returned instead.
So, when called, `parse_statements` will return an `&Vec<Statement>` rather than cloning the `Vec`.
This is useful as a performance optimization.
(You may recall the `return_ref` annotation from the [ir](./ir.md) section of the tutorial,
(You may recall the `returns(ref)` annotation from the [ir](./ir.md) section of the tutorial,
where it was placed on struct fields, with roughly the same meaning.)

View file

@ -14,8 +14,8 @@
mod macro_if;
mod maybe_backdate;
mod maybe_clone;
mod maybe_default;
mod return_mode;
mod setup_accumulator_impl;
mod setup_input_struct;
mod setup_interned_struct;

View file

@ -2,7 +2,7 @@
#[macro_export]
macro_rules! maybe_backdate {
(
($maybe_clone:ident, no_backdate, $maybe_default:ident),
($return_mode:ident, no_backdate, $maybe_default:ident),
$field_ty:ty,
$old_field_place:expr,
$new_field_place:expr,
@ -20,7 +20,7 @@ macro_rules! maybe_backdate {
};
(
($maybe_clone:ident, backdate, $maybe_default:ident),
($return_mode:ident, backdate, $maybe_default:ident),
$field_ty:ty,
$old_field_place:expr,
$new_field_place:expr,

View file

@ -1,40 +0,0 @@
/// Generate either `field_ref_expr` or a clone of that expr.
///
/// Used when generating field getters.
#[macro_export]
macro_rules! maybe_clone {
(
(no_clone, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
$field_ref_expr
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
std::clone::Clone::clone($field_ref_expr)
};
}
#[macro_export]
macro_rules! maybe_cloned_ty {
(
(no_clone, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
& $db_lt $field_ty
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
$field_ty
};
}

View file

@ -4,7 +4,7 @@
#[macro_export]
macro_rules! maybe_default {
(
($maybe_clone:ident, $maybe_backdate:ident, default),
($return_mode:ident, $maybe_backdate:ident, default),
$field_ty:ty,
$field_ref_expr:expr,
) => {
@ -12,7 +12,7 @@ macro_rules! maybe_default {
};
(
($maybe_clone:ident, $maybe_backdate:ident, required),
($return_mode:ident, $maybe_backdate:ident, required),
$field_ty:ty,
$field_ref_expr:expr,
) => {
@ -22,11 +22,11 @@ macro_rules! maybe_default {
#[macro_export]
macro_rules! maybe_default_tt {
(($maybe_clone:ident, $maybe_backdate:ident, default) => $($t:tt)*) => {
(($return_mode:ident, $maybe_backdate:ident, default) => $($t:tt)*) => {
$($t)*
};
(($maybe_clone:ident, $maybe_backdate:ident, required) => $($t:tt)*) => {
(($return_mode:ident, $maybe_backdate:ident, required) => $($t:tt)*) => {
};
}

View file

@ -0,0 +1,104 @@
/// Generate the expression for the return type, depending on the return mode defined in [`salsa-macros::options::Options::returns`]
///
/// Used when generating field getters.
#[macro_export]
macro_rules! return_mode_expression {
(
(copy, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
*$field_ref_expr
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::core::clone::Clone::clone($field_ref_expr)
};
(
(ref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
$field_ref_expr
};
(
(deref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::core::ops::Deref::deref($field_ref_expr)
};
(
(as_ref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::salsa::SalsaAsRef::as_ref($field_ref_expr)
};
(
(as_deref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::salsa::SalsaAsDeref::as_deref($field_ref_expr)
};
}
#[macro_export]
macro_rules! return_mode_ty {
(
(copy, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
$field_ty
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
$field_ty
};
(
(ref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
& $db_lt $field_ty
};
(
(deref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
& $db_lt <$field_ty as ::core::ops::Deref>::Target
};
(
(as_ref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
<$field_ty as ::salsa::SalsaAsRef>::AsRef<$db_lt>
};
(
(as_deref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
<$field_ty as ::salsa::SalsaAsDeref>::AsDeref<$db_lt>
};
}

View file

@ -182,7 +182,7 @@ macro_rules! setup_input_struct {
}
$(
$field_getter_vis fn $field_getter_id<'db, $Db>(self, db: &'db $Db) -> $zalsa::maybe_cloned_ty!($field_option, 'db, $field_ty)
$field_getter_vis fn $field_getter_id<'db, $Db>(self, db: &'db $Db) -> $zalsa::return_mode_ty!($field_option, 'db, $field_ty)
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database,
@ -192,7 +192,7 @@ macro_rules! setup_input_struct {
self,
$field_index,
);
$zalsa::maybe_clone!(
$zalsa::return_mode_expression!(
$field_option,
$field_ty,
&fields.$field_index,

View file

@ -215,13 +215,13 @@ macro_rules! setup_interned_struct {
}
$(
$field_getter_vis fn $field_getter_id<$Db>(self, db: &'db $Db) -> $zalsa::maybe_cloned_ty!($field_option, 'db, $field_ty)
$field_getter_vis fn $field_getter_id<$Db>(self, db: &'db $Db) -> $zalsa::return_mode_ty!($field_option, 'db, $field_ty)
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database,
{
let fields = $Configuration::ingredient(db).fields(db.as_dyn_database(), self);
$zalsa::maybe_clone!(
$zalsa::return_mode_expression!(
$field_option,
$field_ty,
&fields.$field_index,

View file

@ -55,8 +55,8 @@ macro_rules! setup_tracked_fn {
// LRU capacity (a literal, maybe 0)
lru: $lru:tt,
// True if we `return_ref` flag was given to the function
return_ref: $return_ref:tt,
// The return mode for the function, see `salsa_macros::options::Option::returns`
return_mode: $return_mode:tt,
assert_return_type_is_update: {$($assert_return_type_is_update:tt)*},
@ -80,13 +80,7 @@ macro_rules! setup_tracked_fn {
$vis fn $fn_name<$db_lt>(
$db: &$db_lt dyn $Db,
$($input_id: $input_ty,)*
) -> salsa::plumbing::macro_if! {
if $return_ref {
&$db_lt $output_ty
} else {
$output_ty
}
} {
) -> salsa::plumbing::return_mode_ty!(($return_mode, __, __), $db_lt, $output_ty) {
use salsa::plumbing as $zalsa;
struct $Configuration;
@ -372,13 +366,7 @@ macro_rules! setup_tracked_fn {
}
};
$zalsa::macro_if! {
if $return_ref {
result
} else {
<$output_ty as std::clone::Clone>::clone(result)
}
}
$zalsa::return_mode_expression!(($return_mode, __, __), $output_ty, result,)
})
}
// The struct needs be last in the macro expansion in order to make the tracked

View file

@ -52,24 +52,24 @@ macro_rules! setup_tracked_struct {
// A set of "field options" for each tracked field.
//
// Each field option is a tuple `(maybe_clone, maybe_backdate)` where:
// Each field option is a tuple `(return_mode, maybe_backdate)` where:
//
// * `maybe_clone` is either the identifier `clone` or `no_clone`
// * `return_mode` is an identifier as specified in `salsa_macros::options::Option::returns`
// * `maybe_backdate` is either the identifier `backdate` or `no_backdate`
//
// These are used to drive conditional logic for each field via recursive macro invocation
// (see e.g. @maybe_clone below).
// (see e.g. @return_mode below).
tracked_options: [$($tracked_option:tt),*],
// A set of "field options" for each untracked field.
//
// Each field option is a tuple `(maybe_clone, maybe_backdate)` where:
// Each field option is a tuple `(return_mode, maybe_backdate)` where:
//
// * `maybe_clone` is either the identifier `clone` or `no_clone`
// * `return_mode` is an identifier as specified in `salsa_macros::options::Option::returns`
// * `maybe_backdate` is either the identifier `backdate` or `no_backdate`
//
// These are used to drive conditional logic for each field via recursive macro invocation
// (see e.g. @maybe_clone below).
// (see e.g. @return_mode below).
untracked_options: [$($untracked_option:tt),*],
// Number of tracked fields.
@ -260,14 +260,14 @@ macro_rules! setup_tracked_struct {
}
$(
$tracked_getter_vis fn $tracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::maybe_cloned_ty!($tracked_option, $db_lt, $tracked_ty)
$tracked_getter_vis fn $tracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::return_mode_ty!($tracked_option, $db_lt, $tracked_ty)
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database,
{
let db = db.as_dyn_database();
let fields = $Configuration::ingredient(db).tracked_field(db, self, $relative_tracked_index);
$crate::maybe_clone!(
$crate::return_mode_expression!(
$tracked_option,
$tracked_ty,
&fields.$absolute_tracked_index,
@ -276,14 +276,14 @@ macro_rules! setup_tracked_struct {
)*
$(
$untracked_getter_vis fn $untracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::maybe_cloned_ty!($untracked_option, $db_lt, $untracked_ty)
$untracked_getter_vis fn $untracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::return_mode_ty!($untracked_option, $db_lt, $untracked_ty)
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database,
{
let db = db.as_dyn_database();
let fields = $Configuration::ingredient(db).untracked_field(db, self);
$crate::maybe_clone!(
$crate::return_mode_expression!(
$untracked_option,
$untracked_ty,
&fields.$absolute_untracked_index,

View file

@ -29,7 +29,7 @@ pub(crate) fn accumulator(
struct Accumulator;
impl AllowedOptions for Accumulator {
const RETURN_REF: bool = false;
const RETURNS: bool = false;
const SPECIFY: bool = false;
const NO_EQ: bool = false;
const DEBUG: bool = false;

View file

@ -33,7 +33,7 @@ type InputArgs = Options<InputStruct>;
struct InputStruct;
impl crate::options::AllowedOptions for InputStruct {
const RETURN_REF: bool = false;
const RETURNS: bool = false;
const SPECIFY: bool = false;

View file

@ -33,7 +33,7 @@ type InternedArgs = Options<InternedStruct>;
struct InternedStruct;
impl crate::options::AllowedOptions for InternedStruct {
const RETURN_REF: bool = false;
const RETURNS: bool = false;
const SPECIFY: bool = false;

View file

@ -1,6 +1,7 @@
use std::marker::PhantomData;
use syn::ext::IdentExt;
use syn::parenthesized;
use syn::spanned::Spanned;
/// "Options" are flags that can be supplied to the various salsa related
@ -10,10 +11,11 @@ use syn::spanned::Spanned;
/// trait.
#[derive(Debug)]
pub(crate) struct Options<A: AllowedOptions> {
/// The `return_ref` option is used to signal that field/return type is "by ref"
/// The `returns` option is used to configure the "return mode" for the field/function.
/// This may be one of `copy`, `clone`, `ref`, `as_ref`, `as_deref`.
///
/// If this is `Some`, the value is the `ref` identifier.
pub return_ref: Option<syn::Ident>,
/// If this is `Some`, the value is the ident representing the selected mode.
pub returns: Option<syn::Ident>,
/// The `no_eq` option is used to signal that a given field does not implement
/// the `Eq` trait and cannot be compared for equality.
@ -96,7 +98,7 @@ pub(crate) struct Options<A: AllowedOptions> {
impl<A: AllowedOptions> Default for Options<A> {
fn default() -> Self {
Self {
return_ref: Default::default(),
returns: Default::default(),
specify: Default::default(),
no_eq: Default::default(),
debug: Default::default(),
@ -118,7 +120,7 @@ impl<A: AllowedOptions> Default for Options<A> {
/// These flags determine which options are allowed in a given context
pub(crate) trait AllowedOptions {
const RETURN_REF: bool;
const RETURNS: bool;
const SPECIFY: bool;
const NO_EQ: bool;
const DEBUG: bool;
@ -144,18 +146,21 @@ impl<A: AllowedOptions> syn::parse::Parse for Options<A> {
while !input.is_empty() {
let ident: syn::Ident = syn::Ident::parse_any(input)?;
if ident == "return_ref" {
if A::RETURN_REF {
if let Some(old) = options.return_ref.replace(ident) {
if ident == "returns" {
let content;
parenthesized!(content in input);
let mode = syn::Ident::parse_any(&content)?;
if A::RETURNS {
if let Some(old) = options.returns.replace(mode) {
return Err(syn::Error::new(
old.span(),
"option `return_ref` provided twice",
"option `returns` provided twice",
));
}
} else {
return Err(syn::Error::new(
ident.span(),
"`return_ref` option not allowed here",
"`returns` option not allowed here",
));
}
} else if ident == "no_eq" {

View file

@ -26,6 +26,7 @@
//! * this could be optimized, particularly for interned fields
use proc_macro2::{Ident, Literal, Span, TokenStream};
use syn::{ext::IdentExt, spanned::Spanned};
use crate::db_lifetime;
use crate::options::{AllowedOptions, Options};
@ -58,19 +59,22 @@ pub(crate) struct SalsaField<'s> {
pub(crate) has_tracked_attr: bool,
pub(crate) has_default_attr: bool,
pub(crate) has_ref_attr: bool,
pub(crate) returns: syn::Ident,
pub(crate) has_no_eq_attr: bool,
get_name: syn::Ident,
set_name: syn::Ident,
}
const BANNED_FIELD_NAMES: &[&str] = &["from", "new"];
const ALLOWED_RETURN_MODES: &[&str] = &["copy", "clone", "ref", "deref", "as_ref", "as_deref"];
#[allow(clippy::type_complexity)]
pub(crate) const FIELD_OPTION_ATTRIBUTES: &[(&str, fn(&syn::Attribute, &mut SalsaField))] = &[
("tracked", |_, ef| ef.has_tracked_attr = true),
("default", |_, ef| ef.has_default_attr = true),
("return_ref", |_, ef| ef.has_ref_attr = true),
("returns", |attr, ef| {
ef.returns = attr.parse_args_with(syn::Ident::parse_any).unwrap();
}),
("no_eq", |_, ef| ef.has_no_eq_attr = true),
("get", |attr, ef| {
ef.get_name = attr.parse_args().unwrap();
@ -364,10 +368,11 @@ impl<'s> SalsaField<'s> {
let get_name = Ident::new(&field_name_str, field_name.span());
let set_name = Ident::new(&format!("set_{field_name_str}",), field_name.span());
let returns = Ident::new("clone", field.span());
let mut result = SalsaField {
field,
has_tracked_attr: false,
has_ref_attr: false,
returns,
has_default_attr: false,
has_no_eq_attr: false,
get_name,
@ -383,15 +388,22 @@ impl<'s> SalsaField<'s> {
}
}
// Validate return mode
if !ALLOWED_RETURN_MODES
.iter()
.any(|mode| mode == &result.returns.to_string())
{
return Err(syn::Error::new(
result.returns.span(),
format!("Invalid return mode. Allowed modes are: {ALLOWED_RETURN_MODES:?}"),
));
}
Ok(result)
}
fn options(&self) -> TokenStream {
let clone_ident = if self.has_ref_attr {
syn::Ident::new("no_clone", Span::call_site())
} else {
syn::Ident::new("clone", Span::call_site())
};
let returns = &self.returns;
let backdate_ident = if self.has_no_eq_attr {
syn::Ident::new("no_backdate", Span::call_site())
@ -405,6 +417,6 @@ impl<'s> SalsaField<'s> {
syn::Ident::new("required", Span::call_site())
};
quote!((#clone_ident, #backdate_ident, #default_ident))
quote!((#returns, #backdate_ident, #default_ident))
}
}

View file

@ -1,7 +1,7 @@
use proc_macro2::{Literal, Span, TokenStream};
use quote::ToTokens;
use syn::spanned::Spanned;
use syn::ItemFn;
use syn::{Ident, ItemFn};
use crate::hygiene::Hygiene;
use crate::options::Options;
@ -26,7 +26,7 @@ pub type FnArgs = Options<TrackedFn>;
pub struct TrackedFn;
impl crate::options::AllowedOptions for TrackedFn {
const RETURN_REF: bool = true;
const RETURNS: bool = true;
const SPECIFY: bool = true;
@ -67,6 +67,8 @@ struct ValidFn<'item> {
db_path: &'item syn::Path,
}
const ALLOWED_RETURN_MODES: &[&str] = &["copy", "clone", "ref", "deref", "as_ref", "as_deref"];
#[allow(non_snake_case)]
impl Macro {
fn try_fn(&self, item: syn::ItemFn) -> syn::Result<TokenStream> {
@ -146,7 +148,22 @@ impl Macro {
let lru = Literal::usize_unsuffixed(self.args.lru.unwrap_or(0));
let return_ref: bool = self.args.return_ref.is_some();
let return_mode = self
.args
.returns
.clone()
.unwrap_or(Ident::new("clone", Span::call_site()));
// Validate return mode
if !ALLOWED_RETURN_MODES
.iter()
.any(|mode| mode == &return_mode.to_string())
{
return Err(syn::Error::new(
return_mode.span(),
format!("Invalid return mode. Allowed modes are: {ALLOWED_RETURN_MODES:?}"),
));
}
// The path expression is responsible for emitting the primary span in the diagnostic we
// want, so by uniformly using `output_ty.span()` we ensure that the diagnostic is emitted
@ -183,7 +200,7 @@ impl Macro {
values_equal: {#eq},
needs_interner: #needs_interner,
lru: #lru,
return_ref: #return_ref,
return_mode: #return_mode,
assert_return_type_is_update: { #assert_return_type_is_update },
unused_names: [
#zalsa,

View file

@ -296,13 +296,13 @@ impl Macro {
args: &FnArgs,
db_lt: &Option<syn::Lifetime>,
) -> syn::Result<()> {
if let Some(return_ref) = &args.return_ref {
if let Some(returns) = &args.returns {
if let syn::ReturnType::Type(_, t) = &mut sig.output {
**t = parse_quote!(& #db_lt #t)
} else {
return Err(syn::Error::new_spanned(
return_ref,
"return_ref attribute requires explicit return type",
returns,
"returns attribute requires explicit return type",
));
};
}

View file

@ -28,7 +28,7 @@ type TrackedArgs = Options<TrackedStruct>;
struct TrackedStruct;
impl crate::options::AllowedOptions for TrackedStruct {
const RETURN_REF: bool = false;
const RETURNS: bool = false;
const SPECIFY: bool = false;

View file

@ -5,7 +5,7 @@ use ordered_float::OrderedFloat;
// ANCHOR: input
#[salsa::input(debug)]
pub struct SourceProgram {
#[return_ref]
#[returns(ref)]
pub text: String,
}
// ANCHOR_END: input
@ -13,13 +13,13 @@ pub struct SourceProgram {
// ANCHOR: interned_ids
#[salsa::interned(debug)]
pub struct VariableId<'db> {
#[return_ref]
#[returns(ref)]
pub text: String,
}
#[salsa::interned(debug)]
pub struct FunctionId<'db> {
#[return_ref]
#[returns(ref)]
pub text: String,
}
// ANCHOR_END: interned_ids
@ -28,7 +28,7 @@ pub struct FunctionId<'db> {
#[salsa::tracked(debug)]
pub struct Program<'db> {
#[tracked]
#[return_ref]
#[returns(ref)]
pub statements: Vec<Statement<'db>>,
}
// ANCHOR_END: program
@ -93,11 +93,11 @@ pub struct Function<'db> {
name_span: Span<'db>,
#[tracked]
#[return_ref]
#[returns(ref)]
pub args: Vec<VariableId<'db>>,
#[tracked]
#[return_ref]
#[returns(ref)]
pub body: Expression<'db>,
}
// ANCHOR_END: functions

View file

@ -67,7 +67,7 @@ fn main() -> Result<()> {
#[salsa::input]
struct File {
path: PathBuf,
#[return_ref]
#[returns(ref)]
contents: String,
}
@ -158,7 +158,7 @@ impl Diagnostic {
#[salsa::tracked]
struct ParsedFile<'db> {
value: u32,
#[return_ref]
#[returns(ref)]
links: Vec<ParsedFile<'db>>,
}

View file

@ -22,6 +22,7 @@ mod memo_ingredient_indices;
mod nonce;
#[cfg(feature = "rayon")]
mod parallel;
mod return_mode;
mod revision;
mod runtime;
mod salsa_struct;
@ -49,6 +50,8 @@ pub use self::event::{Event, EventKind};
pub use self::id::Id;
pub use self::input::setter::Setter;
pub use self::key::DatabaseKeyIndex;
pub use self::return_mode::SalsaAsDeref;
pub use self::return_mode::SalsaAsRef;
pub use self::revision::Revision;
pub use self::runtime::Runtime;
pub use self::storage::{Storage, StorageHandle};
@ -71,9 +74,9 @@ pub mod plumbing {
pub use std::option::Option::{self, None, Some};
pub use salsa_macro_rules::{
macro_if, maybe_backdate, maybe_clone, maybe_cloned_ty, maybe_default, maybe_default_tt,
setup_accumulator_impl, setup_input_struct, setup_interned_struct, setup_method_body,
setup_tracked_fn, setup_tracked_struct, unexpected_cycle_initial,
macro_if, maybe_backdate, maybe_default, maybe_default_tt, return_mode_expression,
return_mode_ty, setup_accumulator_impl, setup_input_struct, setup_interned_struct,
setup_method_body, setup_tracked_fn, setup_tracked_struct, unexpected_cycle_initial,
unexpected_cycle_recovery,
};

69
src/return_mode.rs Normal file
View file

@ -0,0 +1,69 @@
//! User-implementable salsa traits for refining the return type via `returns(as_ref)` and `returns(as_deref)`.
use std::ops::Deref;
/// Used to determine the return type and value for tracked fields and functions annotated with `returns(as_ref)`.
pub trait SalsaAsRef {
// The type returned by tracked fields and functions annotated with `returns(as_ref)`.
type AsRef<'a>
where
Self: 'a;
// The value returned by tracked fields and functions annotated with `returns(as_ref)`.
fn as_ref(&self) -> Self::AsRef<'_>;
}
impl<T> SalsaAsRef for Option<T> {
type AsRef<'a>
= Option<&'a T>
where
Self: 'a;
fn as_ref(&self) -> Self::AsRef<'_> {
self.as_ref()
}
}
impl<T, E> SalsaAsRef for Result<T, E> {
type AsRef<'a>
= Result<&'a T, &'a E>
where
Self: 'a;
fn as_ref(&self) -> Self::AsRef<'_> {
self.as_ref()
}
}
/// Used to determine the return type and value for tracked fields and functions annotated with `returns(as_deref)`.
pub trait SalsaAsDeref {
// The type returned by tracked fields and functions annotated with `returns(as_deref)`.
type AsDeref<'a>
where
Self: 'a;
// The value returned by tracked fields and functions annotated with `returns(as_deref)`.
fn as_deref(&self) -> Self::AsDeref<'_>;
}
impl<T: Deref> SalsaAsDeref for Option<T> {
type AsDeref<'a>
= Option<&'a T::Target>
where
Self: 'a;
fn as_deref(&self) -> Self::AsDeref<'_> {
self.as_deref()
}
}
impl<T: Deref, E> SalsaAsDeref for Result<T, E> {
type AsDeref<'a>
= Result<&'a T::Target, &'a E>
where
Self: 'a;
fn as_deref(&self) -> Self::AsDeref<'_> {
self.as_deref()
}
}

View file

@ -37,7 +37,7 @@ fn compute(db: &dyn LogDatabase, input: List) -> u32 {
result
}
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
fn accumulated(db: &dyn LogDatabase, input: List) -> Vec<u32> {
db.push_log(format!("accumulated({input:?})"));
compute::accumulated::<Integers>(db, input)

View file

@ -1,4 +1,4 @@
#[salsa::accumulator(return_ref)]
#[salsa::accumulator(returns(ref))]
struct AccWithRetRef(u32);
#[salsa::accumulator(specify)]

View file

@ -1,8 +1,8 @@
error: `return_ref` option not allowed here
error: `returns` option not allowed here
--> tests/compile-fail/accumulator_incompatibles.rs:1:22
|
1 | #[salsa::accumulator(return_ref)]
| ^^^^^^^^^^
1 | #[salsa::accumulator(returns(ref))]
| ^^^^^^^
error: `specify` option not allowed here
--> tests/compile-fail/accumulator_incompatibles.rs:4:22

View file

@ -1,4 +1,4 @@
#[salsa::input(return_ref)]
#[salsa::input(returns(ref))]
struct InputWithRetRef(u32);
#[salsa::input(specify)]

View file

@ -1,8 +1,8 @@
error: `return_ref` option not allowed here
error: `returns` option not allowed here
--> tests/compile-fail/input_struct_incompatibles.rs:1:16
|
1 | #[salsa::input(return_ref)]
| ^^^^^^^^^^
1 | #[salsa::input(returns(ref))]
| ^^^^^^^
error: `specify` option not allowed here
--> tests/compile-fail/input_struct_incompatibles.rs:4:16

View file

@ -1,4 +1,4 @@
#[salsa::interned(return_ref)]
#[salsa::interned(returns(ref))]
struct InternedWithRetRef {
field: u32,
}

View file

@ -1,8 +1,8 @@
error: `return_ref` option not allowed here
error: `returns` option not allowed here
--> tests/compile-fail/interned_struct_incompatibles.rs:1:19
|
1 | #[salsa::interned(return_ref)]
| ^^^^^^^^^^
1 | #[salsa::interned(returns(ref))]
| ^^^^^^^
error: `specify` option not allowed here
--> tests/compile-fail/interned_struct_incompatibles.rs:6:19

View file

@ -0,0 +1,20 @@
use salsa::Database as Db;
#[salsa::input]
struct MyInput {
#[returns(clone)]
text: String,
}
#[salsa::tracked(returns(not_a_return_mode))]
fn tracked_fn_invalid_return_mode(db: &dyn Db, input: MyInput) -> String {
input.text(db)
}
#[salsa::input]
struct MyInvalidInput {
#[returns(not_a_return_mode)]
text: String,
}
fn main() { }

View file

@ -0,0 +1,17 @@
error: Invalid return mode. Allowed modes are: ["copy", "clone", "ref", "deref", "as_ref", "as_deref"]
--> tests/compile-fail/invalid_return_mode.rs:9:26
|
9 | #[salsa::tracked(returns(not_a_return_mode))]
| ^^^^^^^^^^^^^^^^^
error: Invalid return mode. Allowed modes are: ["copy", "clone", "ref", "deref", "as_ref", "as_deref"]
--> tests/compile-fail/invalid_return_mode.rs:16:15
|
16 | #[returns(not_a_return_mode)]
| ^^^^^^^^^^^^^^^^^
error: cannot find attribute `returns` in this scope
--> tests/compile-fail/invalid_return_mode.rs:16:7
|
16 | #[returns(not_a_return_mode)]
| ^^^^^^^

View file

@ -2,7 +2,7 @@ use salsa::Database as Db;
#[salsa::input]
struct MyInput {
#[return_ref]
#[returns(ref)]
text: String,
}

View file

@ -3,7 +3,7 @@ struct MyTracked<'db> {
field: u32,
}
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
impl<'db> std::default::Default for MyTracked<'db> {
fn default() -> Self {}
}

View file

@ -1,8 +1,8 @@
error: unexpected token
--> tests/compile-fail/tracked_impl_incompatibles.rs:6:18
|
6 | #[salsa::tracked(return_ref)]
| ^^^^^^^^^^
6 | #[salsa::tracked(returns(ref))]
| ^^^^^^^
error: unexpected token
--> tests/compile-fail/tracked_impl_incompatibles.rs:11:18

View file

@ -1,4 +1,4 @@
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
struct TrackedWithRetRef {
field: u32,
}

View file

@ -1,8 +1,8 @@
error: `return_ref` option not allowed here
error: `returns` option not allowed here
--> tests/compile-fail/tracked_struct_incompatibles.rs:1:18
|
1 | #[salsa::tracked(return_ref)]
| ^^^^^^^^^^
1 | #[salsa::tracked(returns(ref))]
| ^^^^^^^
error: `specify` option not allowed here
--> tests/compile-fail/tracked_struct_incompatibles.rs:6:18

View file

@ -32,7 +32,7 @@ impl Value {
/// `max_iterate`, `min_panic`, `max_panic`) for testing cycle behaviors.
#[salsa::input]
struct Inputs {
#[return_ref]
#[returns(ref)]
inputs: Vec<Input>,
}

View file

@ -12,7 +12,6 @@ struct Input {
#[salsa::interned(debug)]
struct Output<'db> {
#[return_ref]
value: u32,
}
@ -170,7 +169,7 @@ fn nested_cycle_fewer_dependencies_in_first_iteration() {
}
}
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
fn index<'db>(db: &'db dyn salsa::Database, input: Input) -> Index<'db> {
Index {
scope: Scope::new(db, input.value(db) * 2),

View file

@ -30,10 +30,10 @@ struct Edge {
#[salsa::tracked(debug)]
struct Node<'db> {
#[return_ref]
#[returns(ref)]
name: String,
#[return_ref]
#[returns(deref)]
#[tracked]
edges: Vec<Edge>,
@ -45,7 +45,7 @@ struct GraphInput {
simple: bool,
}
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
fn create_graph(db: &dyn salsa::Database, input: GraphInput) -> Graph<'_> {
if input.simple(db) {
let a = Node::new(db, "a".to_string(), vec![], input);

View file

@ -17,7 +17,7 @@ struct MyTracked<'db> {
identifier: u32,
#[tracked]
#[return_ref]
#[returns(ref)]
field: Bomb,
}

View file

@ -34,7 +34,7 @@ impl MyInput {
#[salsa::interned(constructor = from_string)]
struct MyInterned<'db> {
#[get(text)]
#[return_ref]
#[returns(ref)]
field: String,
}

172
tests/return_mode.rs Normal file
View file

@ -0,0 +1,172 @@
use salsa::Database;
#[salsa::input]
struct DefaultInput {
text: String,
}
#[salsa::tracked]
fn default_fn(db: &dyn Database, input: DefaultInput) -> String {
let input: String = input.text(db);
input
}
#[test]
fn default_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = DefaultInput::new(db, "Test".into());
let x: String = default_fn(db, input);
expect_test::expect![[r#"
"Test"
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct CopyInput {
#[returns(copy)]
text: &'static str,
}
#[salsa::tracked(returns(copy))]
fn copy_fn(db: &dyn Database, input: CopyInput) -> &'static str {
let input: &'static str = input.text(db);
input
}
#[test]
fn copy_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = CopyInput::new(db, "Test");
let x: &str = copy_fn(db, input);
expect_test::expect![[r#"
"Test"
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct CloneInput {
#[returns(clone)]
text: String,
}
#[salsa::tracked(returns(clone))]
fn clone_fn(db: &dyn Database, input: CloneInput) -> String {
let input: String = input.text(db);
input
}
#[test]
fn clone_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = CloneInput::new(db, "Test".into());
let x: String = clone_fn(db, input);
expect_test::expect![[r#"
"Test"
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct RefInput {
#[returns(ref)]
text: String,
}
#[salsa::tracked(returns(ref))]
fn ref_fn(db: &dyn Database, input: RefInput) -> String {
let input: &String = input.text(db);
input.to_owned()
}
#[test]
fn ref_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = RefInput::new(db, "Test".into());
let x: &String = ref_fn(db, input);
expect_test::expect![[r#"
"Test"
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct DerefInput {
#[returns(deref)]
text: String,
}
#[salsa::tracked(returns(deref))]
fn deref_fn(db: &dyn Database, input: DerefInput) -> String {
let input: &str = input.text(db);
input.to_owned()
}
#[test]
fn deref_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = DerefInput::new(db, "Test".into());
let x: &str = deref_fn(db, input);
expect_test::expect![[r#"
"Test"
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct AsRefInput {
#[returns(as_ref)]
text: Option<String>,
}
#[salsa::tracked(returns(as_ref))]
fn as_ref_fn(db: &dyn Database, input: AsRefInput) -> Option<String> {
let input: Option<&String> = input.text(db);
input.cloned()
}
#[test]
fn as_ref_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = AsRefInput::new(db, Some("Test".into()));
let x: Option<&String> = as_ref_fn(db, input);
expect_test::expect![[r#"
Some(
"Test",
)
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct AsDerefInput {
#[returns(as_deref)]
text: Option<String>,
}
#[salsa::tracked(returns(as_deref))]
fn as_deref_fn(db: &dyn Database, input: AsDerefInput) -> Option<String> {
let input: Option<&str> = input.text(db);
input.map(|s| s.to_owned())
}
#[test]
fn as_deref_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = AsDerefInput::new(db, Some("Test".into()));
let x: Option<&str> = as_deref_fn(db, input);
expect_test::expect![[r#"
Some(
"Test",
)
"#]]
.assert_debug_eq(&x);
})
}

View file

@ -5,7 +5,7 @@ struct Input {
number: usize,
}
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
fn test(db: &dyn salsa::Database, input: Input) -> Vec<String> {
(0..input.number(db)).map(|i| format!("test {i}")).collect()
}

View file

@ -23,7 +23,7 @@ impl MyInput {
self.field(db) * 2
}
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
fn tracked_fn_ref(self, db: &dyn salsa::Database) -> u32 {
self.field(db) * 3
}

View file

@ -7,7 +7,7 @@ struct Input {
#[salsa::tracked]
impl Input {
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
fn test(self, db: &dyn salsa::Database) -> Vec<String> {
(0..self.number(db)).map(|i| format!("test {i}")).collect()
}

View file

@ -23,7 +23,7 @@ pub struct SourceTree<'db> {
#[salsa::tracked]
impl<'db1> SourceTree<'db1> {
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
pub fn inherent_item_name(self, db: &'db1 dyn Database) -> String {
self.name(db)
}
@ -35,7 +35,7 @@ trait ItemName<'db1> {
#[salsa::tracked]
impl<'db1> ItemName<'db1> for SourceTree<'db1> {
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
fn trait_item_name(self, db: &'db1 dyn Database) -> String {
self.name(db)
}

View file

@ -11,7 +11,7 @@ trait Trait {
#[salsa::tracked]
impl Trait for Input {
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
fn test(self, db: &dyn salsa::Database) -> Vec<String> {
(0..self.number(db)).map(|i| format!("test {i}")).collect()
}

View file

@ -3,6 +3,6 @@ enum Token {}
#[salsa::tracked]
struct TokenTree<'db> {
#[return_ref]
#[returns(ref)]
tokens: Vec<Token>,
}

View file

@ -9,13 +9,13 @@ pub struct SourceTree<'db> {}
#[salsa::tracked]
impl<'db> SourceTree<'db> {
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
pub fn all_items(self, _db: &'db dyn Db) -> Vec<Item> {
todo!()
}
}
#[salsa::tracked(return_ref)]
#[salsa::tracked(returns(ref))]
fn use_tree<'db>(_db: &'db dyn Db, _tree: SourceTree<'db>) {}
#[allow(unused)]