Changed return_ref syntax to returns(as_ref) and returns(cloned) (#772)

* Changed `return_ref` syntax to `returns(as_ref)` and `returns(cloned)`

* Implement

* renamed module for return_mode

* Rename macro, fix docs, add tests, validate return modes

* Cargo fmt

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
CheaterCodes 2025-05-09 09:28:54 +02:00 committed by GitHub
parent d1da99132d
commit 13a2bd7461
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
54 changed files with 538 additions and 172 deletions

View file

@ -10,7 +10,7 @@ include!("shims/global_alloc_overwrite.rs");
#[salsa::input] #[salsa::input]
pub struct Input { pub struct Input {
#[return_ref] #[returns(ref)]
pub text: String, pub text: String,
} }
@ -22,7 +22,7 @@ pub fn length(db: &dyn salsa::Database, input: Input) -> usize {
#[salsa::interned] #[salsa::interned]
pub struct InternedInput<'db> { pub struct InternedInput<'db> {
#[return_ref] #[returns(ref)]
pub text: String, pub text: String,
} }

View file

@ -15,7 +15,7 @@ struct Tracked<'db> {
number: usize, number: usize,
} }
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
#[inline(never)] #[inline(never)]
fn index<'db>(db: &'db dyn salsa::Database, input: Input) -> Vec<Tracked<'db>> { fn index<'db>(db: &'db dyn salsa::Database, input: Input) -> Vec<Tracked<'db>> {
(0..input.field(db)).map(|i| Tracked::new(db, i)).collect() (0..input.field(db)).map(|i| Tracked::new(db, i)).collect()

View file

@ -79,7 +79,7 @@ pub struct ProgramFile(salsa::Id);
This means that, when you have a `ProgramFile`, you can easily copy it around and put it wherever you like. This means that, when you have a `ProgramFile`, you can easily copy it around and put it wherever you like.
To actually read any of its fields, however, you will need to use the database and a getter method. To actually read any of its fields, however, you will need to use the database and a getter method.
### Reading fields and `return_ref` ### Reading fields and `returns(ref)`
You can access the value of an input's fields by using the getter method. You can access the value of an input's fields by using the getter method.
As this is only reading the field, it just needs a `&`-reference to the database: As this is only reading the field, it just needs a `&`-reference to the database:
@ -89,13 +89,13 @@ let contents: String = file.contents(&db);
``` ```
Invoking the accessor clones the value from the database. Invoking the accessor clones the value from the database.
Sometimes this is not what you want, so you can annotate fields with `#[return_ref]` to indicate that they should return a reference into the database instead: Sometimes this is not what you want, so you can annotate fields with `#[returns(ref)]` to indicate that they should return a reference into the database instead:
```rust ```rust
#[salsa::input] #[salsa::input]
pub struct ProgramFile { pub struct ProgramFile {
pub path: PathBuf, pub path: PathBuf,
#[return_ref] #[returns(ref)]
pub contents: String, pub contents: String,
} }
``` ```
@ -145,7 +145,7 @@ Tracked functions have to follow a particular structure:
- They must take a "Salsa struct" as the second argument -- in our example, this is an input struct, but there are other kinds of Salsa structs we'll describe shortly. - They must take a "Salsa struct" as the second argument -- in our example, this is an input struct, but there are other kinds of Salsa structs we'll describe shortly.
- They _can_ take additional arguments, but it's faster and better if they don't. - They _can_ take additional arguments, but it's faster and better if they don't.
Tracked functions can return any clone-able type. A clone is required since, when the value is cached, the result will be cloned out of the database. Tracked functions can also be annotated with `#[return_ref]` if you would prefer to return a reference into the database instead (if `parse_file` were so annotated, then callers would actually get back an `&Ast`, for example). Tracked functions can return any clone-able type. A clone is required since, when the value is cached, the result will be cloned out of the database. Tracked functions can also be annotated with `#[returns(ref)]` if you would prefer to return a reference into the database instead (if `parse_file` were so annotated, then callers would actually get back an `&Ast`, for example).
## Tracked structs ## Tracked structs
@ -158,7 +158,7 @@ Example:
```rust ```rust
#[salsa::tracked] #[salsa::tracked]
struct Ast<'db> { struct Ast<'db> {
#[return_ref] #[returns(ref)]
top_level_items: Vec<Item>, top_level_items: Vec<Item>,
} }
``` ```
@ -252,7 +252,7 @@ Most compilers, for example, will define a type to represent a user identifier:
```rust ```rust
#[salsa::interned] #[salsa::interned]
struct Word { struct Word {
#[return_ref] #[returns(ref)]
pub text: String, pub text: String,
} }
``` ```
@ -269,7 +269,7 @@ let w3 = Word::new(db, "foo".to_string());
When you create two interned structs with the same field values, you are guaranteed to get back the same integer id. So here, we know that `assert_eq!(w1, w3)` is true and `assert_ne!(w1, w2)`. When you create two interned structs with the same field values, you are guaranteed to get back the same integer id. So here, we know that `assert_eq!(w1, w3)` is true and `assert_ne!(w1, w2)`.
You can access the fields of an interned struct using a getter, like `word.text(db)`. These getters respect the `#[return_ref]` annotation. Like tracked structs, the fields of interned structs are immutable. You can access the fields of an interned struct using a getter, like `word.text(db)`. These getters respect the `#[returns(ref)]` annotation. Like tracked structs, the fields of interned structs are immutable.
## Accumulators ## Accumulators

View file

@ -20,7 +20,7 @@ fn parse_module(db: &dyn Db, module: Module) -> Ast {
Ast::parse_text(module_text) Ast::parse_text(module_text)
} }
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
fn module_text(db: &dyn Db, module: Module) -> String { fn module_text(db: &dyn Db, module: Module) -> String {
panic!("text for module `{module:?}` not set") panic!("text for module `{module:?}` not set")
} }

View file

@ -92,7 +92,7 @@ Apart from the fields being immutable, the API for working with a tracked struct
- You can create a new value by using `new`: e.g., `Program::new(&db, some_statements)` - You can create a new value by using `new`: e.g., `Program::new(&db, some_statements)`
- You use a getter to read the value of a field, just like with an input (e.g., `my_func.statements(db)` to read the `statements` field). - You use a getter to read the value of a field, just like with an input (e.g., `my_func.statements(db)` to read the `statements` field).
- In this case, the field is tagged as `#[return_ref]`, which means that the getter will return a `&Vec<Statement>`, instead of cloning the vector. - In this case, the field is tagged as `#[returns(ref)]`, which means that the getter will return a `&Vec<Statement>`, instead of cloning the vector.
### The `'db` lifetime ### The `'db` lifetime

View file

@ -61,12 +61,12 @@ Tracked functions may take other arguments as well, though our examples here do
Functions that take additional arguments are less efficient and flexible. Functions that take additional arguments are less efficient and flexible.
It's generally better to structure tracked functions as functions of a single Salsa struct if possible. It's generally better to structure tracked functions as functions of a single Salsa struct if possible.
### The `return_ref` annotation ### The `returns(ref)` annotation
You may have noticed that `parse_statements` is tagged with `#[salsa::tracked(return_ref)]`. You may have noticed that `parse_statements` is tagged with `#[salsa::tracked(returns(ref))]`.
Ordinarily, when you call a tracked function, the result you get back is cloned out of the database. Ordinarily, when you call a tracked function, the result you get back is cloned out of the database.
The `return_ref` attribute means that a reference into the database is returned instead. The `returns(ref)` attribute means that a reference into the database is returned instead.
So, when called, `parse_statements` will return an `&Vec<Statement>` rather than cloning the `Vec`. So, when called, `parse_statements` will return an `&Vec<Statement>` rather than cloning the `Vec`.
This is useful as a performance optimization. This is useful as a performance optimization.
(You may recall the `return_ref` annotation from the [ir](./ir.md) section of the tutorial, (You may recall the `returns(ref)` annotation from the [ir](./ir.md) section of the tutorial,
where it was placed on struct fields, with roughly the same meaning.) where it was placed on struct fields, with roughly the same meaning.)

View file

@ -14,8 +14,8 @@
mod macro_if; mod macro_if;
mod maybe_backdate; mod maybe_backdate;
mod maybe_clone;
mod maybe_default; mod maybe_default;
mod return_mode;
mod setup_accumulator_impl; mod setup_accumulator_impl;
mod setup_input_struct; mod setup_input_struct;
mod setup_interned_struct; mod setup_interned_struct;

View file

@ -2,7 +2,7 @@
#[macro_export] #[macro_export]
macro_rules! maybe_backdate { macro_rules! maybe_backdate {
( (
($maybe_clone:ident, no_backdate, $maybe_default:ident), ($return_mode:ident, no_backdate, $maybe_default:ident),
$field_ty:ty, $field_ty:ty,
$old_field_place:expr, $old_field_place:expr,
$new_field_place:expr, $new_field_place:expr,
@ -20,7 +20,7 @@ macro_rules! maybe_backdate {
}; };
( (
($maybe_clone:ident, backdate, $maybe_default:ident), ($return_mode:ident, backdate, $maybe_default:ident),
$field_ty:ty, $field_ty:ty,
$old_field_place:expr, $old_field_place:expr,
$new_field_place:expr, $new_field_place:expr,

View file

@ -1,40 +0,0 @@
/// Generate either `field_ref_expr` or a clone of that expr.
///
/// Used when generating field getters.
#[macro_export]
macro_rules! maybe_clone {
(
(no_clone, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
$field_ref_expr
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
std::clone::Clone::clone($field_ref_expr)
};
}
#[macro_export]
macro_rules! maybe_cloned_ty {
(
(no_clone, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
& $db_lt $field_ty
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
$field_ty
};
}

View file

@ -4,7 +4,7 @@
#[macro_export] #[macro_export]
macro_rules! maybe_default { macro_rules! maybe_default {
( (
($maybe_clone:ident, $maybe_backdate:ident, default), ($return_mode:ident, $maybe_backdate:ident, default),
$field_ty:ty, $field_ty:ty,
$field_ref_expr:expr, $field_ref_expr:expr,
) => { ) => {
@ -12,7 +12,7 @@ macro_rules! maybe_default {
}; };
( (
($maybe_clone:ident, $maybe_backdate:ident, required), ($return_mode:ident, $maybe_backdate:ident, required),
$field_ty:ty, $field_ty:ty,
$field_ref_expr:expr, $field_ref_expr:expr,
) => { ) => {
@ -22,11 +22,11 @@ macro_rules! maybe_default {
#[macro_export] #[macro_export]
macro_rules! maybe_default_tt { macro_rules! maybe_default_tt {
(($maybe_clone:ident, $maybe_backdate:ident, default) => $($t:tt)*) => { (($return_mode:ident, $maybe_backdate:ident, default) => $($t:tt)*) => {
$($t)* $($t)*
}; };
(($maybe_clone:ident, $maybe_backdate:ident, required) => $($t:tt)*) => { (($return_mode:ident, $maybe_backdate:ident, required) => $($t:tt)*) => {
}; };
} }

View file

@ -0,0 +1,104 @@
/// Generate the expression for the return type, depending on the return mode defined in [`salsa-macros::options::Options::returns`]
///
/// Used when generating field getters.
#[macro_export]
macro_rules! return_mode_expression {
(
(copy, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
*$field_ref_expr
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::core::clone::Clone::clone($field_ref_expr)
};
(
(ref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
$field_ref_expr
};
(
(deref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::core::ops::Deref::deref($field_ref_expr)
};
(
(as_ref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::salsa::SalsaAsRef::as_ref($field_ref_expr)
};
(
(as_deref, $maybe_backdate:ident, $maybe_default:ident),
$field_ty:ty,
$field_ref_expr:expr,
) => {
::salsa::SalsaAsDeref::as_deref($field_ref_expr)
};
}
#[macro_export]
macro_rules! return_mode_ty {
(
(copy, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
$field_ty
};
(
(clone, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
$field_ty
};
(
(ref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
& $db_lt $field_ty
};
(
(deref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
& $db_lt <$field_ty as ::core::ops::Deref>::Target
};
(
(as_ref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
<$field_ty as ::salsa::SalsaAsRef>::AsRef<$db_lt>
};
(
(as_deref, $maybe_backdate:ident, $maybe_default:ident),
$db_lt:lifetime,
$field_ty:ty
) => {
<$field_ty as ::salsa::SalsaAsDeref>::AsDeref<$db_lt>
};
}

View file

@ -182,7 +182,7 @@ macro_rules! setup_input_struct {
} }
$( $(
$field_getter_vis fn $field_getter_id<'db, $Db>(self, db: &'db $Db) -> $zalsa::maybe_cloned_ty!($field_option, 'db, $field_ty) $field_getter_vis fn $field_getter_id<'db, $Db>(self, db: &'db $Db) -> $zalsa::return_mode_ty!($field_option, 'db, $field_ty)
where where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database, $Db: ?Sized + $zalsa::Database,
@ -192,7 +192,7 @@ macro_rules! setup_input_struct {
self, self,
$field_index, $field_index,
); );
$zalsa::maybe_clone!( $zalsa::return_mode_expression!(
$field_option, $field_option,
$field_ty, $field_ty,
&fields.$field_index, &fields.$field_index,

View file

@ -215,13 +215,13 @@ macro_rules! setup_interned_struct {
} }
$( $(
$field_getter_vis fn $field_getter_id<$Db>(self, db: &'db $Db) -> $zalsa::maybe_cloned_ty!($field_option, 'db, $field_ty) $field_getter_vis fn $field_getter_id<$Db>(self, db: &'db $Db) -> $zalsa::return_mode_ty!($field_option, 'db, $field_ty)
where where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database, $Db: ?Sized + $zalsa::Database,
{ {
let fields = $Configuration::ingredient(db).fields(db.as_dyn_database(), self); let fields = $Configuration::ingredient(db).fields(db.as_dyn_database(), self);
$zalsa::maybe_clone!( $zalsa::return_mode_expression!(
$field_option, $field_option,
$field_ty, $field_ty,
&fields.$field_index, &fields.$field_index,

View file

@ -55,8 +55,8 @@ macro_rules! setup_tracked_fn {
// LRU capacity (a literal, maybe 0) // LRU capacity (a literal, maybe 0)
lru: $lru:tt, lru: $lru:tt,
// True if we `return_ref` flag was given to the function // The return mode for the function, see `salsa_macros::options::Option::returns`
return_ref: $return_ref:tt, return_mode: $return_mode:tt,
assert_return_type_is_update: {$($assert_return_type_is_update:tt)*}, assert_return_type_is_update: {$($assert_return_type_is_update:tt)*},
@ -80,13 +80,7 @@ macro_rules! setup_tracked_fn {
$vis fn $fn_name<$db_lt>( $vis fn $fn_name<$db_lt>(
$db: &$db_lt dyn $Db, $db: &$db_lt dyn $Db,
$($input_id: $input_ty,)* $($input_id: $input_ty,)*
) -> salsa::plumbing::macro_if! { ) -> salsa::plumbing::return_mode_ty!(($return_mode, __, __), $db_lt, $output_ty) {
if $return_ref {
&$db_lt $output_ty
} else {
$output_ty
}
} {
use salsa::plumbing as $zalsa; use salsa::plumbing as $zalsa;
struct $Configuration; struct $Configuration;
@ -372,13 +366,7 @@ macro_rules! setup_tracked_fn {
} }
}; };
$zalsa::macro_if! { $zalsa::return_mode_expression!(($return_mode, __, __), $output_ty, result,)
if $return_ref {
result
} else {
<$output_ty as std::clone::Clone>::clone(result)
}
}
}) })
} }
// The struct needs be last in the macro expansion in order to make the tracked // The struct needs be last in the macro expansion in order to make the tracked

View file

@ -52,24 +52,24 @@ macro_rules! setup_tracked_struct {
// A set of "field options" for each tracked field. // A set of "field options" for each tracked field.
// //
// Each field option is a tuple `(maybe_clone, maybe_backdate)` where: // Each field option is a tuple `(return_mode, maybe_backdate)` where:
// //
// * `maybe_clone` is either the identifier `clone` or `no_clone` // * `return_mode` is an identifier as specified in `salsa_macros::options::Option::returns`
// * `maybe_backdate` is either the identifier `backdate` or `no_backdate` // * `maybe_backdate` is either the identifier `backdate` or `no_backdate`
// //
// These are used to drive conditional logic for each field via recursive macro invocation // These are used to drive conditional logic for each field via recursive macro invocation
// (see e.g. @maybe_clone below). // (see e.g. @return_mode below).
tracked_options: [$($tracked_option:tt),*], tracked_options: [$($tracked_option:tt),*],
// A set of "field options" for each untracked field. // A set of "field options" for each untracked field.
// //
// Each field option is a tuple `(maybe_clone, maybe_backdate)` where: // Each field option is a tuple `(return_mode, maybe_backdate)` where:
// //
// * `maybe_clone` is either the identifier `clone` or `no_clone` // * `return_mode` is an identifier as specified in `salsa_macros::options::Option::returns`
// * `maybe_backdate` is either the identifier `backdate` or `no_backdate` // * `maybe_backdate` is either the identifier `backdate` or `no_backdate`
// //
// These are used to drive conditional logic for each field via recursive macro invocation // These are used to drive conditional logic for each field via recursive macro invocation
// (see e.g. @maybe_clone below). // (see e.g. @return_mode below).
untracked_options: [$($untracked_option:tt),*], untracked_options: [$($untracked_option:tt),*],
// Number of tracked fields. // Number of tracked fields.
@ -260,14 +260,14 @@ macro_rules! setup_tracked_struct {
} }
$( $(
$tracked_getter_vis fn $tracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::maybe_cloned_ty!($tracked_option, $db_lt, $tracked_ty) $tracked_getter_vis fn $tracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::return_mode_ty!($tracked_option, $db_lt, $tracked_ty)
where where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database, $Db: ?Sized + $zalsa::Database,
{ {
let db = db.as_dyn_database(); let db = db.as_dyn_database();
let fields = $Configuration::ingredient(db).tracked_field(db, self, $relative_tracked_index); let fields = $Configuration::ingredient(db).tracked_field(db, self, $relative_tracked_index);
$crate::maybe_clone!( $crate::return_mode_expression!(
$tracked_option, $tracked_option,
$tracked_ty, $tracked_ty,
&fields.$absolute_tracked_index, &fields.$absolute_tracked_index,
@ -276,14 +276,14 @@ macro_rules! setup_tracked_struct {
)* )*
$( $(
$untracked_getter_vis fn $untracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::maybe_cloned_ty!($untracked_option, $db_lt, $untracked_ty) $untracked_getter_vis fn $untracked_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::return_mode_ty!($untracked_option, $db_lt, $untracked_ty)
where where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + $zalsa::Database, $Db: ?Sized + $zalsa::Database,
{ {
let db = db.as_dyn_database(); let db = db.as_dyn_database();
let fields = $Configuration::ingredient(db).untracked_field(db, self); let fields = $Configuration::ingredient(db).untracked_field(db, self);
$crate::maybe_clone!( $crate::return_mode_expression!(
$untracked_option, $untracked_option,
$untracked_ty, $untracked_ty,
&fields.$absolute_untracked_index, &fields.$absolute_untracked_index,

View file

@ -29,7 +29,7 @@ pub(crate) fn accumulator(
struct Accumulator; struct Accumulator;
impl AllowedOptions for Accumulator { impl AllowedOptions for Accumulator {
const RETURN_REF: bool = false; const RETURNS: bool = false;
const SPECIFY: bool = false; const SPECIFY: bool = false;
const NO_EQ: bool = false; const NO_EQ: bool = false;
const DEBUG: bool = false; const DEBUG: bool = false;

View file

@ -33,7 +33,7 @@ type InputArgs = Options<InputStruct>;
struct InputStruct; struct InputStruct;
impl crate::options::AllowedOptions for InputStruct { impl crate::options::AllowedOptions for InputStruct {
const RETURN_REF: bool = false; const RETURNS: bool = false;
const SPECIFY: bool = false; const SPECIFY: bool = false;

View file

@ -33,7 +33,7 @@ type InternedArgs = Options<InternedStruct>;
struct InternedStruct; struct InternedStruct;
impl crate::options::AllowedOptions for InternedStruct { impl crate::options::AllowedOptions for InternedStruct {
const RETURN_REF: bool = false; const RETURNS: bool = false;
const SPECIFY: bool = false; const SPECIFY: bool = false;

View file

@ -1,6 +1,7 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use syn::ext::IdentExt; use syn::ext::IdentExt;
use syn::parenthesized;
use syn::spanned::Spanned; use syn::spanned::Spanned;
/// "Options" are flags that can be supplied to the various salsa related /// "Options" are flags that can be supplied to the various salsa related
@ -10,10 +11,11 @@ use syn::spanned::Spanned;
/// trait. /// trait.
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct Options<A: AllowedOptions> { pub(crate) struct Options<A: AllowedOptions> {
/// The `return_ref` option is used to signal that field/return type is "by ref" /// The `returns` option is used to configure the "return mode" for the field/function.
/// This may be one of `copy`, `clone`, `ref`, `as_ref`, `as_deref`.
/// ///
/// If this is `Some`, the value is the `ref` identifier. /// If this is `Some`, the value is the ident representing the selected mode.
pub return_ref: Option<syn::Ident>, pub returns: Option<syn::Ident>,
/// The `no_eq` option is used to signal that a given field does not implement /// The `no_eq` option is used to signal that a given field does not implement
/// the `Eq` trait and cannot be compared for equality. /// the `Eq` trait and cannot be compared for equality.
@ -96,7 +98,7 @@ pub(crate) struct Options<A: AllowedOptions> {
impl<A: AllowedOptions> Default for Options<A> { impl<A: AllowedOptions> Default for Options<A> {
fn default() -> Self { fn default() -> Self {
Self { Self {
return_ref: Default::default(), returns: Default::default(),
specify: Default::default(), specify: Default::default(),
no_eq: Default::default(), no_eq: Default::default(),
debug: Default::default(), debug: Default::default(),
@ -118,7 +120,7 @@ impl<A: AllowedOptions> Default for Options<A> {
/// These flags determine which options are allowed in a given context /// These flags determine which options are allowed in a given context
pub(crate) trait AllowedOptions { pub(crate) trait AllowedOptions {
const RETURN_REF: bool; const RETURNS: bool;
const SPECIFY: bool; const SPECIFY: bool;
const NO_EQ: bool; const NO_EQ: bool;
const DEBUG: bool; const DEBUG: bool;
@ -144,18 +146,21 @@ impl<A: AllowedOptions> syn::parse::Parse for Options<A> {
while !input.is_empty() { while !input.is_empty() {
let ident: syn::Ident = syn::Ident::parse_any(input)?; let ident: syn::Ident = syn::Ident::parse_any(input)?;
if ident == "return_ref" { if ident == "returns" {
if A::RETURN_REF { let content;
if let Some(old) = options.return_ref.replace(ident) { parenthesized!(content in input);
let mode = syn::Ident::parse_any(&content)?;
if A::RETURNS {
if let Some(old) = options.returns.replace(mode) {
return Err(syn::Error::new( return Err(syn::Error::new(
old.span(), old.span(),
"option `return_ref` provided twice", "option `returns` provided twice",
)); ));
} }
} else { } else {
return Err(syn::Error::new( return Err(syn::Error::new(
ident.span(), ident.span(),
"`return_ref` option not allowed here", "`returns` option not allowed here",
)); ));
} }
} else if ident == "no_eq" { } else if ident == "no_eq" {

View file

@ -26,6 +26,7 @@
//! * this could be optimized, particularly for interned fields //! * this could be optimized, particularly for interned fields
use proc_macro2::{Ident, Literal, Span, TokenStream}; use proc_macro2::{Ident, Literal, Span, TokenStream};
use syn::{ext::IdentExt, spanned::Spanned};
use crate::db_lifetime; use crate::db_lifetime;
use crate::options::{AllowedOptions, Options}; use crate::options::{AllowedOptions, Options};
@ -58,19 +59,22 @@ pub(crate) struct SalsaField<'s> {
pub(crate) has_tracked_attr: bool, pub(crate) has_tracked_attr: bool,
pub(crate) has_default_attr: bool, pub(crate) has_default_attr: bool,
pub(crate) has_ref_attr: bool, pub(crate) returns: syn::Ident,
pub(crate) has_no_eq_attr: bool, pub(crate) has_no_eq_attr: bool,
get_name: syn::Ident, get_name: syn::Ident,
set_name: syn::Ident, set_name: syn::Ident,
} }
const BANNED_FIELD_NAMES: &[&str] = &["from", "new"]; const BANNED_FIELD_NAMES: &[&str] = &["from", "new"];
const ALLOWED_RETURN_MODES: &[&str] = &["copy", "clone", "ref", "deref", "as_ref", "as_deref"];
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub(crate) const FIELD_OPTION_ATTRIBUTES: &[(&str, fn(&syn::Attribute, &mut SalsaField))] = &[ pub(crate) const FIELD_OPTION_ATTRIBUTES: &[(&str, fn(&syn::Attribute, &mut SalsaField))] = &[
("tracked", |_, ef| ef.has_tracked_attr = true), ("tracked", |_, ef| ef.has_tracked_attr = true),
("default", |_, ef| ef.has_default_attr = true), ("default", |_, ef| ef.has_default_attr = true),
("return_ref", |_, ef| ef.has_ref_attr = true), ("returns", |attr, ef| {
ef.returns = attr.parse_args_with(syn::Ident::parse_any).unwrap();
}),
("no_eq", |_, ef| ef.has_no_eq_attr = true), ("no_eq", |_, ef| ef.has_no_eq_attr = true),
("get", |attr, ef| { ("get", |attr, ef| {
ef.get_name = attr.parse_args().unwrap(); ef.get_name = attr.parse_args().unwrap();
@ -364,10 +368,11 @@ impl<'s> SalsaField<'s> {
let get_name = Ident::new(&field_name_str, field_name.span()); let get_name = Ident::new(&field_name_str, field_name.span());
let set_name = Ident::new(&format!("set_{field_name_str}",), field_name.span()); let set_name = Ident::new(&format!("set_{field_name_str}",), field_name.span());
let returns = Ident::new("clone", field.span());
let mut result = SalsaField { let mut result = SalsaField {
field, field,
has_tracked_attr: false, has_tracked_attr: false,
has_ref_attr: false, returns,
has_default_attr: false, has_default_attr: false,
has_no_eq_attr: false, has_no_eq_attr: false,
get_name, get_name,
@ -383,15 +388,22 @@ impl<'s> SalsaField<'s> {
} }
} }
// Validate return mode
if !ALLOWED_RETURN_MODES
.iter()
.any(|mode| mode == &result.returns.to_string())
{
return Err(syn::Error::new(
result.returns.span(),
format!("Invalid return mode. Allowed modes are: {ALLOWED_RETURN_MODES:?}"),
));
}
Ok(result) Ok(result)
} }
fn options(&self) -> TokenStream { fn options(&self) -> TokenStream {
let clone_ident = if self.has_ref_attr { let returns = &self.returns;
syn::Ident::new("no_clone", Span::call_site())
} else {
syn::Ident::new("clone", Span::call_site())
};
let backdate_ident = if self.has_no_eq_attr { let backdate_ident = if self.has_no_eq_attr {
syn::Ident::new("no_backdate", Span::call_site()) syn::Ident::new("no_backdate", Span::call_site())
@ -405,6 +417,6 @@ impl<'s> SalsaField<'s> {
syn::Ident::new("required", Span::call_site()) syn::Ident::new("required", Span::call_site())
}; };
quote!((#clone_ident, #backdate_ident, #default_ident)) quote!((#returns, #backdate_ident, #default_ident))
} }
} }

View file

@ -1,7 +1,7 @@
use proc_macro2::{Literal, Span, TokenStream}; use proc_macro2::{Literal, Span, TokenStream};
use quote::ToTokens; use quote::ToTokens;
use syn::spanned::Spanned; use syn::spanned::Spanned;
use syn::ItemFn; use syn::{Ident, ItemFn};
use crate::hygiene::Hygiene; use crate::hygiene::Hygiene;
use crate::options::Options; use crate::options::Options;
@ -26,7 +26,7 @@ pub type FnArgs = Options<TrackedFn>;
pub struct TrackedFn; pub struct TrackedFn;
impl crate::options::AllowedOptions for TrackedFn { impl crate::options::AllowedOptions for TrackedFn {
const RETURN_REF: bool = true; const RETURNS: bool = true;
const SPECIFY: bool = true; const SPECIFY: bool = true;
@ -67,6 +67,8 @@ struct ValidFn<'item> {
db_path: &'item syn::Path, db_path: &'item syn::Path,
} }
const ALLOWED_RETURN_MODES: &[&str] = &["copy", "clone", "ref", "deref", "as_ref", "as_deref"];
#[allow(non_snake_case)] #[allow(non_snake_case)]
impl Macro { impl Macro {
fn try_fn(&self, item: syn::ItemFn) -> syn::Result<TokenStream> { fn try_fn(&self, item: syn::ItemFn) -> syn::Result<TokenStream> {
@ -146,7 +148,22 @@ impl Macro {
let lru = Literal::usize_unsuffixed(self.args.lru.unwrap_or(0)); let lru = Literal::usize_unsuffixed(self.args.lru.unwrap_or(0));
let return_ref: bool = self.args.return_ref.is_some(); let return_mode = self
.args
.returns
.clone()
.unwrap_or(Ident::new("clone", Span::call_site()));
// Validate return mode
if !ALLOWED_RETURN_MODES
.iter()
.any(|mode| mode == &return_mode.to_string())
{
return Err(syn::Error::new(
return_mode.span(),
format!("Invalid return mode. Allowed modes are: {ALLOWED_RETURN_MODES:?}"),
));
}
// The path expression is responsible for emitting the primary span in the diagnostic we // The path expression is responsible for emitting the primary span in the diagnostic we
// want, so by uniformly using `output_ty.span()` we ensure that the diagnostic is emitted // want, so by uniformly using `output_ty.span()` we ensure that the diagnostic is emitted
@ -183,7 +200,7 @@ impl Macro {
values_equal: {#eq}, values_equal: {#eq},
needs_interner: #needs_interner, needs_interner: #needs_interner,
lru: #lru, lru: #lru,
return_ref: #return_ref, return_mode: #return_mode,
assert_return_type_is_update: { #assert_return_type_is_update }, assert_return_type_is_update: { #assert_return_type_is_update },
unused_names: [ unused_names: [
#zalsa, #zalsa,

View file

@ -296,13 +296,13 @@ impl Macro {
args: &FnArgs, args: &FnArgs,
db_lt: &Option<syn::Lifetime>, db_lt: &Option<syn::Lifetime>,
) -> syn::Result<()> { ) -> syn::Result<()> {
if let Some(return_ref) = &args.return_ref { if let Some(returns) = &args.returns {
if let syn::ReturnType::Type(_, t) = &mut sig.output { if let syn::ReturnType::Type(_, t) = &mut sig.output {
**t = parse_quote!(& #db_lt #t) **t = parse_quote!(& #db_lt #t)
} else { } else {
return Err(syn::Error::new_spanned( return Err(syn::Error::new_spanned(
return_ref, returns,
"return_ref attribute requires explicit return type", "returns attribute requires explicit return type",
)); ));
}; };
} }

View file

@ -28,7 +28,7 @@ type TrackedArgs = Options<TrackedStruct>;
struct TrackedStruct; struct TrackedStruct;
impl crate::options::AllowedOptions for TrackedStruct { impl crate::options::AllowedOptions for TrackedStruct {
const RETURN_REF: bool = false; const RETURNS: bool = false;
const SPECIFY: bool = false; const SPECIFY: bool = false;

View file

@ -5,7 +5,7 @@ use ordered_float::OrderedFloat;
// ANCHOR: input // ANCHOR: input
#[salsa::input(debug)] #[salsa::input(debug)]
pub struct SourceProgram { pub struct SourceProgram {
#[return_ref] #[returns(ref)]
pub text: String, pub text: String,
} }
// ANCHOR_END: input // ANCHOR_END: input
@ -13,13 +13,13 @@ pub struct SourceProgram {
// ANCHOR: interned_ids // ANCHOR: interned_ids
#[salsa::interned(debug)] #[salsa::interned(debug)]
pub struct VariableId<'db> { pub struct VariableId<'db> {
#[return_ref] #[returns(ref)]
pub text: String, pub text: String,
} }
#[salsa::interned(debug)] #[salsa::interned(debug)]
pub struct FunctionId<'db> { pub struct FunctionId<'db> {
#[return_ref] #[returns(ref)]
pub text: String, pub text: String,
} }
// ANCHOR_END: interned_ids // ANCHOR_END: interned_ids
@ -28,7 +28,7 @@ pub struct FunctionId<'db> {
#[salsa::tracked(debug)] #[salsa::tracked(debug)]
pub struct Program<'db> { pub struct Program<'db> {
#[tracked] #[tracked]
#[return_ref] #[returns(ref)]
pub statements: Vec<Statement<'db>>, pub statements: Vec<Statement<'db>>,
} }
// ANCHOR_END: program // ANCHOR_END: program
@ -93,11 +93,11 @@ pub struct Function<'db> {
name_span: Span<'db>, name_span: Span<'db>,
#[tracked] #[tracked]
#[return_ref] #[returns(ref)]
pub args: Vec<VariableId<'db>>, pub args: Vec<VariableId<'db>>,
#[tracked] #[tracked]
#[return_ref] #[returns(ref)]
pub body: Expression<'db>, pub body: Expression<'db>,
} }
// ANCHOR_END: functions // ANCHOR_END: functions

View file

@ -67,7 +67,7 @@ fn main() -> Result<()> {
#[salsa::input] #[salsa::input]
struct File { struct File {
path: PathBuf, path: PathBuf,
#[return_ref] #[returns(ref)]
contents: String, contents: String,
} }
@ -158,7 +158,7 @@ impl Diagnostic {
#[salsa::tracked] #[salsa::tracked]
struct ParsedFile<'db> { struct ParsedFile<'db> {
value: u32, value: u32,
#[return_ref] #[returns(ref)]
links: Vec<ParsedFile<'db>>, links: Vec<ParsedFile<'db>>,
} }

View file

@ -22,6 +22,7 @@ mod memo_ingredient_indices;
mod nonce; mod nonce;
#[cfg(feature = "rayon")] #[cfg(feature = "rayon")]
mod parallel; mod parallel;
mod return_mode;
mod revision; mod revision;
mod runtime; mod runtime;
mod salsa_struct; mod salsa_struct;
@ -49,6 +50,8 @@ pub use self::event::{Event, EventKind};
pub use self::id::Id; pub use self::id::Id;
pub use self::input::setter::Setter; pub use self::input::setter::Setter;
pub use self::key::DatabaseKeyIndex; pub use self::key::DatabaseKeyIndex;
pub use self::return_mode::SalsaAsDeref;
pub use self::return_mode::SalsaAsRef;
pub use self::revision::Revision; pub use self::revision::Revision;
pub use self::runtime::Runtime; pub use self::runtime::Runtime;
pub use self::storage::{Storage, StorageHandle}; pub use self::storage::{Storage, StorageHandle};
@ -71,9 +74,9 @@ pub mod plumbing {
pub use std::option::Option::{self, None, Some}; pub use std::option::Option::{self, None, Some};
pub use salsa_macro_rules::{ pub use salsa_macro_rules::{
macro_if, maybe_backdate, maybe_clone, maybe_cloned_ty, maybe_default, maybe_default_tt, macro_if, maybe_backdate, maybe_default, maybe_default_tt, return_mode_expression,
setup_accumulator_impl, setup_input_struct, setup_interned_struct, setup_method_body, return_mode_ty, setup_accumulator_impl, setup_input_struct, setup_interned_struct,
setup_tracked_fn, setup_tracked_struct, unexpected_cycle_initial, setup_method_body, setup_tracked_fn, setup_tracked_struct, unexpected_cycle_initial,
unexpected_cycle_recovery, unexpected_cycle_recovery,
}; };

69
src/return_mode.rs Normal file
View file

@ -0,0 +1,69 @@
//! User-implementable salsa traits for refining the return type via `returns(as_ref)` and `returns(as_deref)`.
use std::ops::Deref;
/// Used to determine the return type and value for tracked fields and functions annotated with `returns(as_ref)`.
pub trait SalsaAsRef {
// The type returned by tracked fields and functions annotated with `returns(as_ref)`.
type AsRef<'a>
where
Self: 'a;
// The value returned by tracked fields and functions annotated with `returns(as_ref)`.
fn as_ref(&self) -> Self::AsRef<'_>;
}
impl<T> SalsaAsRef for Option<T> {
type AsRef<'a>
= Option<&'a T>
where
Self: 'a;
fn as_ref(&self) -> Self::AsRef<'_> {
self.as_ref()
}
}
impl<T, E> SalsaAsRef for Result<T, E> {
type AsRef<'a>
= Result<&'a T, &'a E>
where
Self: 'a;
fn as_ref(&self) -> Self::AsRef<'_> {
self.as_ref()
}
}
/// Used to determine the return type and value for tracked fields and functions annotated with `returns(as_deref)`.
pub trait SalsaAsDeref {
// The type returned by tracked fields and functions annotated with `returns(as_deref)`.
type AsDeref<'a>
where
Self: 'a;
// The value returned by tracked fields and functions annotated with `returns(as_deref)`.
fn as_deref(&self) -> Self::AsDeref<'_>;
}
impl<T: Deref> SalsaAsDeref for Option<T> {
type AsDeref<'a>
= Option<&'a T::Target>
where
Self: 'a;
fn as_deref(&self) -> Self::AsDeref<'_> {
self.as_deref()
}
}
impl<T: Deref, E> SalsaAsDeref for Result<T, E> {
type AsDeref<'a>
= Result<&'a T::Target, &'a E>
where
Self: 'a;
fn as_deref(&self) -> Self::AsDeref<'_> {
self.as_deref()
}
}

View file

@ -37,7 +37,7 @@ fn compute(db: &dyn LogDatabase, input: List) -> u32 {
result result
} }
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
fn accumulated(db: &dyn LogDatabase, input: List) -> Vec<u32> { fn accumulated(db: &dyn LogDatabase, input: List) -> Vec<u32> {
db.push_log(format!("accumulated({input:?})")); db.push_log(format!("accumulated({input:?})"));
compute::accumulated::<Integers>(db, input) compute::accumulated::<Integers>(db, input)

View file

@ -1,4 +1,4 @@
#[salsa::accumulator(return_ref)] #[salsa::accumulator(returns(ref))]
struct AccWithRetRef(u32); struct AccWithRetRef(u32);
#[salsa::accumulator(specify)] #[salsa::accumulator(specify)]

View file

@ -1,8 +1,8 @@
error: `return_ref` option not allowed here error: `returns` option not allowed here
--> tests/compile-fail/accumulator_incompatibles.rs:1:22 --> tests/compile-fail/accumulator_incompatibles.rs:1:22
| |
1 | #[salsa::accumulator(return_ref)] 1 | #[salsa::accumulator(returns(ref))]
| ^^^^^^^^^^ | ^^^^^^^
error: `specify` option not allowed here error: `specify` option not allowed here
--> tests/compile-fail/accumulator_incompatibles.rs:4:22 --> tests/compile-fail/accumulator_incompatibles.rs:4:22

View file

@ -1,4 +1,4 @@
#[salsa::input(return_ref)] #[salsa::input(returns(ref))]
struct InputWithRetRef(u32); struct InputWithRetRef(u32);
#[salsa::input(specify)] #[salsa::input(specify)]

View file

@ -1,8 +1,8 @@
error: `return_ref` option not allowed here error: `returns` option not allowed here
--> tests/compile-fail/input_struct_incompatibles.rs:1:16 --> tests/compile-fail/input_struct_incompatibles.rs:1:16
| |
1 | #[salsa::input(return_ref)] 1 | #[salsa::input(returns(ref))]
| ^^^^^^^^^^ | ^^^^^^^
error: `specify` option not allowed here error: `specify` option not allowed here
--> tests/compile-fail/input_struct_incompatibles.rs:4:16 --> tests/compile-fail/input_struct_incompatibles.rs:4:16

View file

@ -1,4 +1,4 @@
#[salsa::interned(return_ref)] #[salsa::interned(returns(ref))]
struct InternedWithRetRef { struct InternedWithRetRef {
field: u32, field: u32,
} }

View file

@ -1,8 +1,8 @@
error: `return_ref` option not allowed here error: `returns` option not allowed here
--> tests/compile-fail/interned_struct_incompatibles.rs:1:19 --> tests/compile-fail/interned_struct_incompatibles.rs:1:19
| |
1 | #[salsa::interned(return_ref)] 1 | #[salsa::interned(returns(ref))]
| ^^^^^^^^^^ | ^^^^^^^
error: `specify` option not allowed here error: `specify` option not allowed here
--> tests/compile-fail/interned_struct_incompatibles.rs:6:19 --> tests/compile-fail/interned_struct_incompatibles.rs:6:19

View file

@ -0,0 +1,20 @@
use salsa::Database as Db;
#[salsa::input]
struct MyInput {
#[returns(clone)]
text: String,
}
#[salsa::tracked(returns(not_a_return_mode))]
fn tracked_fn_invalid_return_mode(db: &dyn Db, input: MyInput) -> String {
input.text(db)
}
#[salsa::input]
struct MyInvalidInput {
#[returns(not_a_return_mode)]
text: String,
}
fn main() { }

View file

@ -0,0 +1,17 @@
error: Invalid return mode. Allowed modes are: ["copy", "clone", "ref", "deref", "as_ref", "as_deref"]
--> tests/compile-fail/invalid_return_mode.rs:9:26
|
9 | #[salsa::tracked(returns(not_a_return_mode))]
| ^^^^^^^^^^^^^^^^^
error: Invalid return mode. Allowed modes are: ["copy", "clone", "ref", "deref", "as_ref", "as_deref"]
--> tests/compile-fail/invalid_return_mode.rs:16:15
|
16 | #[returns(not_a_return_mode)]
| ^^^^^^^^^^^^^^^^^
error: cannot find attribute `returns` in this scope
--> tests/compile-fail/invalid_return_mode.rs:16:7
|
16 | #[returns(not_a_return_mode)]
| ^^^^^^^

View file

@ -2,7 +2,7 @@ use salsa::Database as Db;
#[salsa::input] #[salsa::input]
struct MyInput { struct MyInput {
#[return_ref] #[returns(ref)]
text: String, text: String,
} }

View file

@ -3,7 +3,7 @@ struct MyTracked<'db> {
field: u32, field: u32,
} }
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
impl<'db> std::default::Default for MyTracked<'db> { impl<'db> std::default::Default for MyTracked<'db> {
fn default() -> Self {} fn default() -> Self {}
} }

View file

@ -1,8 +1,8 @@
error: unexpected token error: unexpected token
--> tests/compile-fail/tracked_impl_incompatibles.rs:6:18 --> tests/compile-fail/tracked_impl_incompatibles.rs:6:18
| |
6 | #[salsa::tracked(return_ref)] 6 | #[salsa::tracked(returns(ref))]
| ^^^^^^^^^^ | ^^^^^^^
error: unexpected token error: unexpected token
--> tests/compile-fail/tracked_impl_incompatibles.rs:11:18 --> tests/compile-fail/tracked_impl_incompatibles.rs:11:18

View file

@ -1,4 +1,4 @@
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
struct TrackedWithRetRef { struct TrackedWithRetRef {
field: u32, field: u32,
} }

View file

@ -1,8 +1,8 @@
error: `return_ref` option not allowed here error: `returns` option not allowed here
--> tests/compile-fail/tracked_struct_incompatibles.rs:1:18 --> tests/compile-fail/tracked_struct_incompatibles.rs:1:18
| |
1 | #[salsa::tracked(return_ref)] 1 | #[salsa::tracked(returns(ref))]
| ^^^^^^^^^^ | ^^^^^^^
error: `specify` option not allowed here error: `specify` option not allowed here
--> tests/compile-fail/tracked_struct_incompatibles.rs:6:18 --> tests/compile-fail/tracked_struct_incompatibles.rs:6:18

View file

@ -32,7 +32,7 @@ impl Value {
/// `max_iterate`, `min_panic`, `max_panic`) for testing cycle behaviors. /// `max_iterate`, `min_panic`, `max_panic`) for testing cycle behaviors.
#[salsa::input] #[salsa::input]
struct Inputs { struct Inputs {
#[return_ref] #[returns(ref)]
inputs: Vec<Input>, inputs: Vec<Input>,
} }

View file

@ -12,7 +12,6 @@ struct Input {
#[salsa::interned(debug)] #[salsa::interned(debug)]
struct Output<'db> { struct Output<'db> {
#[return_ref]
value: u32, value: u32,
} }
@ -170,7 +169,7 @@ fn nested_cycle_fewer_dependencies_in_first_iteration() {
} }
} }
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
fn index<'db>(db: &'db dyn salsa::Database, input: Input) -> Index<'db> { fn index<'db>(db: &'db dyn salsa::Database, input: Input) -> Index<'db> {
Index { Index {
scope: Scope::new(db, input.value(db) * 2), scope: Scope::new(db, input.value(db) * 2),

View file

@ -30,10 +30,10 @@ struct Edge {
#[salsa::tracked(debug)] #[salsa::tracked(debug)]
struct Node<'db> { struct Node<'db> {
#[return_ref] #[returns(ref)]
name: String, name: String,
#[return_ref] #[returns(deref)]
#[tracked] #[tracked]
edges: Vec<Edge>, edges: Vec<Edge>,
@ -45,7 +45,7 @@ struct GraphInput {
simple: bool, simple: bool,
} }
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
fn create_graph(db: &dyn salsa::Database, input: GraphInput) -> Graph<'_> { fn create_graph(db: &dyn salsa::Database, input: GraphInput) -> Graph<'_> {
if input.simple(db) { if input.simple(db) {
let a = Node::new(db, "a".to_string(), vec![], input); let a = Node::new(db, "a".to_string(), vec![], input);

View file

@ -17,7 +17,7 @@ struct MyTracked<'db> {
identifier: u32, identifier: u32,
#[tracked] #[tracked]
#[return_ref] #[returns(ref)]
field: Bomb, field: Bomb,
} }

View file

@ -34,7 +34,7 @@ impl MyInput {
#[salsa::interned(constructor = from_string)] #[salsa::interned(constructor = from_string)]
struct MyInterned<'db> { struct MyInterned<'db> {
#[get(text)] #[get(text)]
#[return_ref] #[returns(ref)]
field: String, field: String,
} }

172
tests/return_mode.rs Normal file
View file

@ -0,0 +1,172 @@
use salsa::Database;
#[salsa::input]
struct DefaultInput {
text: String,
}
#[salsa::tracked]
fn default_fn(db: &dyn Database, input: DefaultInput) -> String {
let input: String = input.text(db);
input
}
#[test]
fn default_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = DefaultInput::new(db, "Test".into());
let x: String = default_fn(db, input);
expect_test::expect![[r#"
"Test"
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct CopyInput {
#[returns(copy)]
text: &'static str,
}
#[salsa::tracked(returns(copy))]
fn copy_fn(db: &dyn Database, input: CopyInput) -> &'static str {
let input: &'static str = input.text(db);
input
}
#[test]
fn copy_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = CopyInput::new(db, "Test");
let x: &str = copy_fn(db, input);
expect_test::expect![[r#"
"Test"
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct CloneInput {
#[returns(clone)]
text: String,
}
#[salsa::tracked(returns(clone))]
fn clone_fn(db: &dyn Database, input: CloneInput) -> String {
let input: String = input.text(db);
input
}
#[test]
fn clone_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = CloneInput::new(db, "Test".into());
let x: String = clone_fn(db, input);
expect_test::expect![[r#"
"Test"
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct RefInput {
#[returns(ref)]
text: String,
}
#[salsa::tracked(returns(ref))]
fn ref_fn(db: &dyn Database, input: RefInput) -> String {
let input: &String = input.text(db);
input.to_owned()
}
#[test]
fn ref_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = RefInput::new(db, "Test".into());
let x: &String = ref_fn(db, input);
expect_test::expect![[r#"
"Test"
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct DerefInput {
#[returns(deref)]
text: String,
}
#[salsa::tracked(returns(deref))]
fn deref_fn(db: &dyn Database, input: DerefInput) -> String {
let input: &str = input.text(db);
input.to_owned()
}
#[test]
fn deref_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = DerefInput::new(db, "Test".into());
let x: &str = deref_fn(db, input);
expect_test::expect![[r#"
"Test"
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct AsRefInput {
#[returns(as_ref)]
text: Option<String>,
}
#[salsa::tracked(returns(as_ref))]
fn as_ref_fn(db: &dyn Database, input: AsRefInput) -> Option<String> {
let input: Option<&String> = input.text(db);
input.cloned()
}
#[test]
fn as_ref_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = AsRefInput::new(db, Some("Test".into()));
let x: Option<&String> = as_ref_fn(db, input);
expect_test::expect![[r#"
Some(
"Test",
)
"#]]
.assert_debug_eq(&x);
})
}
#[salsa::input]
struct AsDerefInput {
#[returns(as_deref)]
text: Option<String>,
}
#[salsa::tracked(returns(as_deref))]
fn as_deref_fn(db: &dyn Database, input: AsDerefInput) -> Option<String> {
let input: Option<&str> = input.text(db);
input.map(|s| s.to_owned())
}
#[test]
fn as_deref_test() {
salsa::DatabaseImpl::new().attach(|db| {
let input = AsDerefInput::new(db, Some("Test".into()));
let x: Option<&str> = as_deref_fn(db, input);
expect_test::expect![[r#"
Some(
"Test",
)
"#]]
.assert_debug_eq(&x);
})
}

View file

@ -5,7 +5,7 @@ struct Input {
number: usize, number: usize,
} }
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
fn test(db: &dyn salsa::Database, input: Input) -> Vec<String> { fn test(db: &dyn salsa::Database, input: Input) -> Vec<String> {
(0..input.number(db)).map(|i| format!("test {i}")).collect() (0..input.number(db)).map(|i| format!("test {i}")).collect()
} }

View file

@ -23,7 +23,7 @@ impl MyInput {
self.field(db) * 2 self.field(db) * 2
} }
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
fn tracked_fn_ref(self, db: &dyn salsa::Database) -> u32 { fn tracked_fn_ref(self, db: &dyn salsa::Database) -> u32 {
self.field(db) * 3 self.field(db) * 3
} }

View file

@ -7,7 +7,7 @@ struct Input {
#[salsa::tracked] #[salsa::tracked]
impl Input { impl Input {
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
fn test(self, db: &dyn salsa::Database) -> Vec<String> { fn test(self, db: &dyn salsa::Database) -> Vec<String> {
(0..self.number(db)).map(|i| format!("test {i}")).collect() (0..self.number(db)).map(|i| format!("test {i}")).collect()
} }

View file

@ -23,7 +23,7 @@ pub struct SourceTree<'db> {
#[salsa::tracked] #[salsa::tracked]
impl<'db1> SourceTree<'db1> { impl<'db1> SourceTree<'db1> {
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
pub fn inherent_item_name(self, db: &'db1 dyn Database) -> String { pub fn inherent_item_name(self, db: &'db1 dyn Database) -> String {
self.name(db) self.name(db)
} }
@ -35,7 +35,7 @@ trait ItemName<'db1> {
#[salsa::tracked] #[salsa::tracked]
impl<'db1> ItemName<'db1> for SourceTree<'db1> { impl<'db1> ItemName<'db1> for SourceTree<'db1> {
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
fn trait_item_name(self, db: &'db1 dyn Database) -> String { fn trait_item_name(self, db: &'db1 dyn Database) -> String {
self.name(db) self.name(db)
} }

View file

@ -11,7 +11,7 @@ trait Trait {
#[salsa::tracked] #[salsa::tracked]
impl Trait for Input { impl Trait for Input {
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
fn test(self, db: &dyn salsa::Database) -> Vec<String> { fn test(self, db: &dyn salsa::Database) -> Vec<String> {
(0..self.number(db)).map(|i| format!("test {i}")).collect() (0..self.number(db)).map(|i| format!("test {i}")).collect()
} }

View file

@ -3,6 +3,6 @@ enum Token {}
#[salsa::tracked] #[salsa::tracked]
struct TokenTree<'db> { struct TokenTree<'db> {
#[return_ref] #[returns(ref)]
tokens: Vec<Token>, tokens: Vec<Token>,
} }

View file

@ -9,13 +9,13 @@ pub struct SourceTree<'db> {}
#[salsa::tracked] #[salsa::tracked]
impl<'db> SourceTree<'db> { impl<'db> SourceTree<'db> {
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
pub fn all_items(self, _db: &'db dyn Db) -> Vec<Item> { pub fn all_items(self, _db: &'db dyn Db) -> Vec<Item> {
todo!() todo!()
} }
} }
#[salsa::tracked(return_ref)] #[salsa::tracked(returns(ref))]
fn use_tree<'db>(_db: &'db dyn Db, _tree: SourceTree<'db>) {} fn use_tree<'db>(_db: &'db dyn Db, _tree: SourceTree<'db>) {}
#[allow(unused)] #[allow(unused)]