diff --git a/helper_crates/vtable/macro/macro.rs b/helper_crates/vtable/macro/macro.rs index 27be17f56..4997dd17b 100644 --- a/helper_crates/vtable/macro/macro.rs +++ b/helper_crates/vtable/macro/macro.rs @@ -53,7 +53,6 @@ pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream { let trait_name = Ident::new(&vtable_name[..vtable_name.len() - 6], input.ident.span()); let to_name = quote::format_ident!("{}TO", trait_name); let impl_name = quote::format_ident!("{}Impl", trait_name); - let type_name = quote::format_ident!("{}Type", trait_name); let module_name = quote::format_ident!("{}_vtable_mod", trait_name); let ref_name = quote::format_ident!("{}Ref", trait_name); let refmut_name = quote::format_ident!("{}RefMut", trait_name); @@ -93,7 +92,33 @@ pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream { field.vis = Visibility::Public(VisPublic { pub_token: Default::default() }); let ident = field.ident.as_ref().unwrap(); - if let Type::BareFn(f) = &mut field.ty { + let mut some = None; + + let func_ty = if let Type::BareFn(f) = &mut field.ty { + Some(f) + } else if let Type::Path(pat) = &mut field.ty { + pat.path.segments.last_mut().and_then(|seg| { + if seg.ident == "Option" { + some = Some(quote!(Some)); + if let PathArguments::AngleBracketed(args) = &mut seg.arguments { + if let Some(GenericArgument::Type(Type::BareFn(f))) = args.args.first_mut() + { + Some(f) + } else { + None + } + } else { + None + } + } else { + None + } + }) + } else { + None + }; + + if let Some(f) = func_ty { let mut sig = Signature { constness: None, asyncness: None, @@ -156,7 +181,7 @@ pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream { .to_compile_error() .into(); } - call_code = Some(quote!(self.vtable.as_ptr(),)); + call_code = Some(quote!(vtable as _,)); continue; } } @@ -269,7 +294,14 @@ pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream { block: parse2(if has_self { quote!({ // Safety: this rely on the vtable being valid, and the ptr being a valid instance for this vtable - unsafe { (self.vtable.as_ref().#ident)(#call_code) } + unsafe { + let vtable = self.vtable.as_ref(); + if let #some(func) = vtable.#ident { + func (#call_code) + } else { + panic!("Called a not-implemented method") + } + } }) } else { // This should never happen: nobody should be able to access the Trait Object directly. @@ -295,15 +327,16 @@ pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream { defaultness: None, sig, block: parse2(quote!({ + let vtable = self; // Safety: this rely on the vtable being valid, and the ptr being a valid instance for this vtable - unsafe { (self.vtable.as_ref().#ident)(#call_code) } + unsafe { (self.#ident)(#call_code) } })) .unwrap(), }); vtable_ctor.push(quote!(#ident: { + #[allow(unused_parens)] #sig_extern { - #[allow(unused_parens)] // This is safe since the self must be a instance of our type unsafe { #[allow(unused)] @@ -311,7 +344,7 @@ pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream { #wrap_trait_call(T::#ident(#self_call #forward_code)) } } - #ident:: + #some(#ident::) },)); } else { vtable_ctor.push(quote!(#ident: { @@ -344,16 +377,6 @@ pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream { })); vtable_ctor.push(quote!(#ident: T::#ident,)); - - generated_type_assoc_fn.push( - parse2(quote! { - pub fn #ident(&self) -> #ty { - // Safety: this rely on the vtable being valid, and the ptr being a valid instance for this vtable - unsafe { self.vtable.as_ref().#ident } - } - }) - .unwrap(), - ); }; } @@ -381,6 +404,7 @@ pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream { #(#vtable_ctor)* } } + #(#generated_type_assoc_fn)* } #generated_trait @@ -401,24 +425,10 @@ pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream { } impl #trait_name for #to_name { #(#generated_to_fn_trait)* } - #[repr(transparent)] - /// Safe wrapper around a VTable. - pub struct #type_name { - vtable: core::ptr::NonNull<#vtable_name> - } - impl #type_name { - pub const unsafe fn from_raw(vtable: core::ptr::NonNull<#vtable_name>) -> Self { - Self { vtable } - } - #(#generated_type_assoc_fn)* - } - unsafe impl core::marker::Sync for #type_name {} - unsafe impl VTableMeta for #vtable_name { type Trait = dyn #trait_name; type VTable = #vtable_name; type TraitObject = #to_name; - type Type = #type_name; #[inline] unsafe fn map_to(from: &Self::TraitObject) -> &Self::Trait { from } #[inline] @@ -426,12 +436,10 @@ pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream { #[inline] unsafe fn get_ptr(from: &Self::TraitObject) -> core::ptr::NonNull { from.ptr.cast() } #[inline] - unsafe fn get_vtable(from: &Self::TraitObject) -> core::ptr::NonNull { from.vtable } + unsafe fn get_vtable(from: &Self::TraitObject) -> &Self::VTable { from.vtable.as_ref() } #[inline] unsafe fn from_raw(vtable: core::ptr::NonNull, ptr: core::ptr::NonNull) -> Self::TraitObject { #to_name { vtable, ptr : ptr.cast() } } - #[inline] - unsafe fn get_type(from: &Self::TraitObject) -> Self::Type { #type_name::from_raw(from.vtable) } } #drop_impl @@ -445,13 +453,8 @@ pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream { ($ty:ty) => { { type T = $ty; - static VTABLE : #vtable_name = #vtable_name { + #vtable_name { #(#vtable_ctor)* - }; - unsafe { - <#vtable_name as ::vtable::VTableMeta>::Type::from_raw( - core::ptr::NonNull::new_unchecked(&VTABLE as *const _ as *mut #vtable_name) - ) } } } diff --git a/helper_crates/vtable/src/lib.rs b/helper_crates/vtable/src/lib.rs index ccc0595ef..74fcd0e22 100644 --- a/helper_crates/vtable/src/lib.rs +++ b/helper_crates/vtable/src/lib.rs @@ -9,9 +9,6 @@ pub unsafe trait VTableMeta { /// that's the vtable struct `HelloVTable` type VTable; - /// That's the safe wrapper around a vtable pointer (`HelloType`) - type Type; - /// That's the trait object that implements the trait. /// NOTE: the size must be 2*size_of type TraitObject: Copy; @@ -28,14 +25,12 @@ pub unsafe trait VTableMeta { /// Return a raw pointer to the inside of the impl unsafe fn get_ptr(from: &Self::TraitObject) -> NonNull; - /// return a raw pointer to the vtable - unsafe fn get_vtable(from: &Self::TraitObject) -> NonNull; /// Create a trait object from its raw parts unsafe fn from_raw(vtable: NonNull, ptr: NonNull) -> Self::TraitObject; - /// return a safe pointer around the vtable - unsafe fn get_type(from: &Self::TraitObject) -> Self::Type; + /// return a reference to the vtable + unsafe fn get_vtable(from: &Self::TraitObject) -> &Self::VTable; } @@ -87,14 +82,11 @@ impl VBox { pub unsafe fn get_ptr(x: &Self) -> NonNull { T::get_ptr(&x.inner) } - pub unsafe fn get_vtable(x: &Self) -> NonNull { - T::get_vtable(&x.inner) - } pub unsafe fn from_raw(vtable: NonNull, ptr: NonNull) -> Self { Self {inner : T::from_raw(vtable, ptr)} } - pub fn get_type(&self) -> T::Type { - unsafe { T::get_type(&self.inner) } + pub fn get_vtable(&self) -> &T::VTable { + unsafe { T::get_vtable(&self.inner) } } } @@ -150,14 +142,11 @@ impl<'a, T: ?Sized + VTableMeta> VRef<'a, T> { pub unsafe fn get_ptr(x: &Self) -> NonNull { T::get_ptr(&x.inner) } - pub unsafe fn get_vtable(x: &Self) -> NonNull { - T::get_vtable(&x.inner) - } pub unsafe fn from_raw(vtable: NonNull, ptr: NonNull) -> Self { Self {inner : T::from_raw(vtable, ptr), _phantom: PhantomData } } - pub fn get_type(&self) -> T::Type { - unsafe { T::get_type(&self.inner) } + pub fn get_vtable(&self) -> &T::VTable { + unsafe { T::get_vtable(&self.inner) } } } @@ -191,12 +180,6 @@ impl<'a, T: ?Sized + VTableMeta> VRefMut<'a, T> { pub unsafe fn get_ptr(x: &Self) -> NonNull { T::get_ptr(&x.inner) } - pub unsafe fn get_vtable(x: &Self) -> NonNull { - T::get_vtable(&x.inner) - } - pub unsafe fn from_raw(vtable: NonNull, ptr: NonNull) -> Self { - Self {inner : T::from_raw(vtable, ptr), _phantom: PhantomData } - } pub fn borrow<'b>(&'b self) -> VRef<'b, T> { unsafe { VRef::from_inner(VRefMut::inner(self)) } } @@ -206,7 +189,7 @@ impl<'a, T: ?Sized + VTableMeta> VRefMut<'a, T> { pub fn into_ref(self) -> VRef<'a, T> { unsafe { VRef::from_inner(VRefMut::inner(&self)) } } - pub fn get_type(&self) -> T::Type { - unsafe { T::get_type(&self.inner) } + pub fn get_vtable(&self) -> &T::VTable { + unsafe { T::get_vtable(&self.inner) } } } diff --git a/helper_crates/vtable/tests/test_vtable.rs b/helper_crates/vtable/tests/test_vtable.rs index 7422dc6b6..81fff7087 100644 --- a/helper_crates/vtable/tests/test_vtable.rs +++ b/helper_crates/vtable/tests/test_vtable.rs @@ -40,18 +40,18 @@ impl HelloConsts for SomeStruct { const CONSTANT: usize = 88; } -static SOME_STRUCT_TYPE : HelloType = HelloVTable_static!(SomeStruct); +static SOME_STRUCT_TYPE : HelloVTable = HelloVTable_static!(SomeStruct); #[test] fn test() { let vt = &SOME_STRUCT_TYPE; assert_eq!(vt.assoc(), 32); - assert_eq!(vt.CONSTANT(), 88); + assert_eq!(vt.CONSTANT, 88); let mut bx = vt.construct(89); assert_eq!(bx.foo(1), 90); assert_eq!(bx.foo_mut(6), 95); assert_eq!(bx.foo(2), 97); - assert_eq!(bx.get_type().CONSTANT(), 88); + assert_eq!(bx.get_vtable().CONSTANT, 88); }