mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-08-04 11:00:05 +00:00
fix: Replace SelfTy
with actual type in tracked methods
This commit is contained in:
parent
af2ec49d80
commit
ad1f84d80f
3 changed files with 184 additions and 4 deletions
|
@ -1,8 +1,10 @@
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use proc_macro2::TokenStream;
|
use proc_macro2::TokenStream;
|
||||||
use quote::ToTokens;
|
use quote::ToTokens;
|
||||||
use syn::parse::Nothing;
|
use syn::{parse::Nothing, visit_mut::VisitMut};
|
||||||
|
|
||||||
use crate::{hygiene::Hygiene, tracked_fn::FnArgs};
|
use crate::{hygiene::Hygiene, tracked_fn::FnArgs, xform::ChangeSelfPath};
|
||||||
|
|
||||||
pub(crate) fn tracked_impl(
|
pub(crate) fn tracked_impl(
|
||||||
args: proc_macro::TokenStream,
|
args: proc_macro::TokenStream,
|
||||||
|
@ -32,8 +34,19 @@ struct MethodArguments<'syn> {
|
||||||
impl Macro {
|
impl Macro {
|
||||||
fn try_generate(&self, mut impl_item: syn::ItemImpl) -> syn::Result<TokenStream> {
|
fn try_generate(&self, mut impl_item: syn::ItemImpl) -> syn::Result<TokenStream> {
|
||||||
let mut member_items = std::mem::take(&mut impl_item.items);
|
let mut member_items = std::mem::take(&mut impl_item.items);
|
||||||
|
let member_idents: HashSet<_> = member_items
|
||||||
|
.iter()
|
||||||
|
.filter_map(|item| match item {
|
||||||
|
syn::ImplItem::Const(it) => Some(it.ident.clone()),
|
||||||
|
syn::ImplItem::Fn(it) => Some(it.sig.ident.clone()),
|
||||||
|
syn::ImplItem::Type(it) => Some(it.ident.clone()),
|
||||||
|
syn::ImplItem::Macro(_) => None,
|
||||||
|
syn::ImplItem::Verbatim(_) => None,
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
for member_item in &mut member_items {
|
for member_item in &mut member_items {
|
||||||
self.modify_member(&impl_item, member_item)?;
|
self.modify_member(&impl_item, member_item, &member_idents)?;
|
||||||
}
|
}
|
||||||
impl_item.items = member_items;
|
impl_item.items = member_items;
|
||||||
Ok(crate::debug::dump_tokens(
|
Ok(crate::debug::dump_tokens(
|
||||||
|
@ -47,6 +60,7 @@ impl Macro {
|
||||||
&self,
|
&self,
|
||||||
impl_item: &syn::ItemImpl,
|
impl_item: &syn::ItemImpl,
|
||||||
member_item: &mut syn::ImplItem,
|
member_item: &mut syn::ImplItem,
|
||||||
|
member_idents: &HashSet<syn::Ident>,
|
||||||
) -> syn::Result<()> {
|
) -> syn::Result<()> {
|
||||||
let syn::ImplItem::Fn(fn_item) = member_item else {
|
let syn::ImplItem::Fn(fn_item) = member_item else {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
@ -59,6 +73,13 @@ impl Macro {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let trait_ = match &impl_item.trait_ {
|
||||||
|
Some((None, path, _)) => Some((path, member_idents)),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
let mut change = ChangeSelfPath::new(self_ty, trait_);
|
||||||
|
change.visit_impl_item_fn_mut(fn_item);
|
||||||
|
|
||||||
let salsa_tracked_attr = fn_item.attrs.remove(tracked_attr_index);
|
let salsa_tracked_attr = fn_item.attrs.remove(tracked_attr_index);
|
||||||
let args: FnArgs = match &salsa_tracked_attr.meta {
|
let args: FnArgs = match &salsa_tracked_attr.meta {
|
||||||
syn::Meta::Path(..) => Default::default(),
|
syn::Meta::Path(..) => Default::default(),
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
use syn::visit_mut::VisitMut;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use quote::ToTokens;
|
||||||
|
use syn::{punctuated::Punctuated, spanned::Spanned, visit_mut::VisitMut};
|
||||||
|
|
||||||
pub(crate) struct ChangeLt<'a> {
|
pub(crate) struct ChangeLt<'a> {
|
||||||
from: Option<&'a str>,
|
from: Option<&'a str>,
|
||||||
|
@ -12,6 +15,7 @@ impl ChangeLt<'_> {
|
||||||
to: db_lt.ident.to_string(),
|
to: db_lt.ident.to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn in_type(mut self, ty: &syn::Type) -> syn::Type {
|
pub fn in_type(mut self, ty: &syn::Type) -> syn::Type {
|
||||||
let mut ty = ty.clone();
|
let mut ty = ty.clone();
|
||||||
self.visit_type_mut(&mut ty);
|
self.visit_type_mut(&mut ty);
|
||||||
|
@ -26,3 +30,114 @@ impl syn::visit_mut::VisitMut for ChangeLt<'_> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) struct ChangeSelfPath<'a> {
|
||||||
|
self_ty: &'a syn::Type,
|
||||||
|
trait_: Option<(&'a syn::Path, &'a HashSet<syn::Ident>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChangeSelfPath<'_> {
|
||||||
|
pub fn new<'a>(
|
||||||
|
self_ty: &'a syn::Type,
|
||||||
|
trait_: Option<(&'a syn::Path, &'a HashSet<syn::Ident>)>,
|
||||||
|
) -> ChangeSelfPath<'a> {
|
||||||
|
ChangeSelfPath { self_ty, trait_ }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl syn::visit_mut::VisitMut for ChangeSelfPath<'_> {
|
||||||
|
fn visit_type_mut(&mut self, i: &mut syn::Type) {
|
||||||
|
if let syn::Type::Path(syn::TypePath { qself: None, path }) = i {
|
||||||
|
if path.segments.len() == 1 && path.segments.first().is_some_and(|s| s.ident == "Self")
|
||||||
|
{
|
||||||
|
let span = path.segments.first().unwrap().span();
|
||||||
|
*i = respan(self.self_ty, span);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
syn::visit_mut::visit_type_mut(self, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) {
|
||||||
|
// `<Self as ..>` cases are handled in `visit_type_mut`
|
||||||
|
if i.qself.is_some() {
|
||||||
|
syn::visit_mut::visit_type_path_mut(self, i);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// A single path `Self` case is handled in `visit_type_mut`
|
||||||
|
if i.path.segments.first().is_some_and(|s| s.ident == "Self") && i.path.segments.len() > 1 {
|
||||||
|
let span = i.path.segments.first().unwrap().span();
|
||||||
|
let ty = Box::new(respan::<syn::Type>(self.self_ty, span));
|
||||||
|
let lt_token = syn::Token;
|
||||||
|
let gt_token = syn::Token;
|
||||||
|
match self.trait_ {
|
||||||
|
// If the next segment's ident is a trait member, replace `Self::` with
|
||||||
|
// `<ActualTy as Trait>::`
|
||||||
|
Some((trait_, member_idents))
|
||||||
|
if member_idents.contains(&i.path.segments.iter().nth(1).unwrap().ident) =>
|
||||||
|
{
|
||||||
|
let qself = syn::QSelf {
|
||||||
|
lt_token,
|
||||||
|
ty,
|
||||||
|
position: trait_.segments.len(),
|
||||||
|
as_token: Some(syn::Token),
|
||||||
|
gt_token,
|
||||||
|
};
|
||||||
|
i.qself = Some(qself);
|
||||||
|
i.path.segments = Punctuated::from_iter(
|
||||||
|
trait_
|
||||||
|
.segments
|
||||||
|
.iter()
|
||||||
|
.chain(i.path.segments.iter().skip(1))
|
||||||
|
.cloned(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// Replace `Self::` with `<ActualTy>::` otherwise
|
||||||
|
_ => {
|
||||||
|
let qself = syn::QSelf {
|
||||||
|
lt_token,
|
||||||
|
ty,
|
||||||
|
position: 0,
|
||||||
|
as_token: None,
|
||||||
|
gt_token,
|
||||||
|
};
|
||||||
|
i.qself = Some(qself);
|
||||||
|
i.path.segments =
|
||||||
|
Punctuated::from_iter(i.path.segments.iter().skip(1).cloned());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
syn::visit_mut::visit_type_path_mut(self, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn respan<T>(t: &T, span: proc_macro2::Span) -> T
|
||||||
|
where
|
||||||
|
T: ToTokens + Spanned + syn::parse::Parse,
|
||||||
|
{
|
||||||
|
let tokens = t.to_token_stream();
|
||||||
|
let respanned = respan_tokenstream(tokens, span);
|
||||||
|
syn::parse2(respanned).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn respan_tokenstream(
|
||||||
|
stream: proc_macro2::TokenStream,
|
||||||
|
span: proc_macro2::Span,
|
||||||
|
) -> proc_macro2::TokenStream {
|
||||||
|
stream
|
||||||
|
.into_iter()
|
||||||
|
.map(|token| respan_token(token, span))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn respan_token(
|
||||||
|
mut token: proc_macro2::TokenTree,
|
||||||
|
span: proc_macro2::Span,
|
||||||
|
) -> proc_macro2::TokenTree {
|
||||||
|
if let proc_macro2::TokenTree::Group(g) = &mut token {
|
||||||
|
*g = proc_macro2::Group::new(g.delimiter(), respan_tokenstream(g.stream(), span));
|
||||||
|
}
|
||||||
|
token.set_span(span);
|
||||||
|
token
|
||||||
|
}
|
||||||
|
|
44
tests/tracked_method_with_self_ty.rs
Normal file
44
tests/tracked_method_with_self_ty.rs
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
//! Test that a `tracked` fn with `Self` in its signature or body on a `salsa::input`
|
||||||
|
//! compiles and executes successfully.
|
||||||
|
#![allow(warnings)]
|
||||||
|
|
||||||
|
trait TrackedTrait {
|
||||||
|
type Type;
|
||||||
|
|
||||||
|
fn tracked_trait_fn(self, db: &dyn salsa::Database, ty: Self::Type) -> Self::Type;
|
||||||
|
|
||||||
|
fn untracked_trait_fn();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::input]
|
||||||
|
struct MyInput {
|
||||||
|
field: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
impl MyInput {
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn tracked_fn(self, db: &dyn salsa::Database, other: Self) -> u32 {
|
||||||
|
self.field(db) + other.field(db)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
impl TrackedTrait for MyInput {
|
||||||
|
type Type = u32;
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn tracked_trait_fn(self, db: &dyn salsa::Database, ty: Self::Type) -> Self::Type {
|
||||||
|
Self::untracked_trait_fn();
|
||||||
|
Self::tracked_fn(self, db, self) + ty
|
||||||
|
}
|
||||||
|
|
||||||
|
fn untracked_trait_fn() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn execute() {
|
||||||
|
let mut db = salsa::DatabaseImpl::new();
|
||||||
|
let object = MyInput::new(&mut db, 10);
|
||||||
|
assert_eq!(object.tracked_trait_fn(&db, 1), 21);
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue