Fix scalar API in extensions, add some error handling

This commit is contained in:
PThorpe92 2025-01-18 15:19:23 -05:00
parent 4a41736f89
commit 956320b7d0
No known key found for this signature in database
GPG key ID: 66DB3FBACBDD05CC
11 changed files with 609 additions and 604 deletions

View file

@ -1,5 +1,7 @@
use syn::parse::ParseStream;
use syn::punctuated::Punctuated;
use syn::{Ident, Token};
use syn::token::Eq;
use syn::{Ident, LitStr, Token};
pub(crate) struct RegisterExtensionInput {
pub aggregates: Vec<Ident>,
@ -44,3 +46,39 @@ impl syn::parse::Parse for RegisterExtensionInput {
})
}
}
pub(crate) struct ScalarInfo {
pub name: String,
pub alias: Option<String>,
}
impl ScalarInfo {
pub fn new(name: String, alias: Option<String>) -> Self {
Self { name, alias }
}
}
impl syn::parse::Parse for ScalarInfo {
fn parse(input: ParseStream) -> syn::parse::Result<Self> {
let mut name = None;
let mut alias = None;
while !input.is_empty() {
if let Ok(ident) = input.parse::<Ident>() {
if ident.to_string().as_str() == "name" {
let _ = input.parse::<Eq>();
name = Some(input.parse::<LitStr>()?);
} else if ident.to_string().as_str() == "alias" {
let _ = input.parse::<Eq>();
alias = Some(input.parse::<LitStr>()?);
}
}
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
}
}
let Some(name) = name else {
return Err(input.error("Expected name"));
};
Ok(Self::new(name.value(), alias.map(|i| i.value())))
}
}

View file

@ -1,7 +1,7 @@
mod args;
use args::RegisterExtensionInput;
use args::{RegisterExtensionInput, ScalarInfo};
use quote::{format_ident, quote};
use syn::{parse_macro_input, DeriveInput};
use syn::{parse_macro_input, DeriveInput, ItemFn};
extern crate proc_macro;
use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree};
use std::collections::HashMap;
@ -138,65 +138,61 @@ fn generate_get_description(
enum_impl.parse().unwrap()
}
#[proc_macro_derive(ScalarDerive)]
pub fn derive_scalar(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
let struct_name = &ast.ident;
let register_fn_name = format_ident!("register_{}", struct_name);
let exec_fn_name = format_ident!("{}_exec", struct_name);
let alias_check = quote! {
if let Some(alias) = scalar.alias() {
let alias_c_name = std::ffi::CString::new(alias).unwrap();
#[proc_macro_attribute]
pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as ItemFn);
let fn_name = &ast.sig.ident;
let scalar_info = parse_macro_input!(attr as ScalarInfo);
let name = &scalar_info.name;
let register_fn_name = format_ident!("register_{}", fn_name);
let fn_body = &ast.block;
let alias_check = if let Some(alias) = &scalar_info.alias {
quote! {
let Ok(alias_c_name) = std::ffi::CString::new(#alias) else {
return ::limbo_ext::ResultCode::Error;
};
(api.register_scalar_function)(
api.ctx,
alias_c_name.as_ptr(),
#exec_fn_name,
#fn_name,
);
}
} else {
quote! {}
};
let expanded = quote! {
impl #struct_name {
#[no_mangle]
pub unsafe extern "C" fn #register_fn_name(
api: *const ::limbo_ext::ExtensionApi
) -> ::limbo_ext::ResultCode {
if api.is_null() {
return ::limbo_ext::RESULT_ERROR;
}
let api = unsafe { &*api };
let scalar = #struct_name;
let name = scalar.name();
let c_name = std::ffi::CString::new(name).unwrap();
(api.register_scalar_function)(
api.ctx,
c_name.as_ptr(),
#exec_fn_name,
);
#alias_check
::limbo_ext::RESULT_OK
#[no_mangle]
pub unsafe extern "C" fn #register_fn_name(
api: *const ::limbo_ext::ExtensionApi
) -> ::limbo_ext::ResultCode {
if api.is_null() {
return ::limbo_ext::ResultCode::Error;
}
let api = unsafe { &*api };
let Ok(c_name) = std::ffi::CString::new(#name) else {
return ::limbo_ext::ResultCode::Error;
};
(api.register_scalar_function)(
api.ctx,
c_name.as_ptr(),
#fn_name,
);
#alias_check
::limbo_ext::ResultCode::OK
}
#[no_mangle]
pub unsafe extern "C" fn #exec_fn_name(
pub unsafe extern "C" fn #fn_name(
argc: i32,
argv: *const ::limbo_ext::Value
) -> ::limbo_ext::Value {
let scalar = #struct_name;
let args_slice = if argv.is_null() || argc <= 0 {
let args = if argv.is_null() || argc <= 0 {
&[]
} else {
unsafe { std::slice::from_raw_parts(argv, argc as usize) }
};
scalar.call(args_slice)
#fn_body
}
};
@ -254,21 +250,20 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream {
api: *const ::limbo_ext::ExtensionApi
) -> ::limbo_ext::ResultCode {
if api.is_null() {
return ::limbo_ext::RESULT_ERROR;
return ::limbo_ext::ResultCode::Error;
}
let api = &*api;
let agg = #struct_name;
let name_str = agg.name();
let name_str = #struct_name::NAME;
let c_name = match std::ffi::CString::new(name_str) {
Ok(cname) => cname,
Err(_) => return ::limbo_ext::RESULT_ERROR,
Err(_) => return ::limbo_ext::ResultCode::Error,
};
(api.register_aggregate_function)(
api.ctx,
c_name.as_ptr(),
agg.args(),
#struct_name::ARGS,
#struct_name::#init_fn_name
as ::limbo_ext::InitAggFunction,
#struct_name::#step_fn_name
@ -297,8 +292,8 @@ pub fn register_extension(input: TokenStream) -> TokenStream {
syn::Ident::new(&format!("register_{}", scalar_ident), scalar_ident.span());
quote! {
{
let result = unsafe { #scalar_ident::#register_fn(api)};
if result != 0 {
let result = unsafe { #register_fn(api)};
if !result.is_ok() {
return result;
}
}
@ -310,7 +305,7 @@ pub fn register_extension(input: TokenStream) -> TokenStream {
quote! {
{
let result = unsafe{ #agg_ident::#register_fn(api)};
if result != 0 {
if !result.is_ok() {
return result;
}
}
@ -319,13 +314,13 @@ pub fn register_extension(input: TokenStream) -> TokenStream {
let expanded = quote! {
#[no_mangle]
pub extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> i32 {
pub extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode {
let api = unsafe { &*api };
#(#scalar_calls)*
#(#aggregate_calls)*
::limbo_ext::RESULT_OK
::limbo_ext::ResultCode::OK
}
};