mirror of
https://github.com/tursodatabase/limbo.git
synced 2025-08-04 18:18:03 +00:00
Fix scalar API in extensions, add some error handling
This commit is contained in:
parent
4a41736f89
commit
956320b7d0
11 changed files with 609 additions and 604 deletions
|
@ -1,5 +1,7 @@
|
|||
use syn::parse::ParseStream;
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::{Ident, Token};
|
||||
use syn::token::Eq;
|
||||
use syn::{Ident, LitStr, Token};
|
||||
|
||||
pub(crate) struct RegisterExtensionInput {
|
||||
pub aggregates: Vec<Ident>,
|
||||
|
@ -44,3 +46,39 @@ impl syn::parse::Parse for RegisterExtensionInput {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ScalarInfo {
|
||||
pub name: String,
|
||||
pub alias: Option<String>,
|
||||
}
|
||||
|
||||
impl ScalarInfo {
|
||||
pub fn new(name: String, alias: Option<String>) -> Self {
|
||||
Self { name, alias }
|
||||
}
|
||||
}
|
||||
|
||||
impl syn::parse::Parse for ScalarInfo {
|
||||
fn parse(input: ParseStream) -> syn::parse::Result<Self> {
|
||||
let mut name = None;
|
||||
let mut alias = None;
|
||||
while !input.is_empty() {
|
||||
if let Ok(ident) = input.parse::<Ident>() {
|
||||
if ident.to_string().as_str() == "name" {
|
||||
let _ = input.parse::<Eq>();
|
||||
name = Some(input.parse::<LitStr>()?);
|
||||
} else if ident.to_string().as_str() == "alias" {
|
||||
let _ = input.parse::<Eq>();
|
||||
alias = Some(input.parse::<LitStr>()?);
|
||||
}
|
||||
}
|
||||
if input.peek(Token![,]) {
|
||||
input.parse::<Token![,]>()?;
|
||||
}
|
||||
}
|
||||
let Some(name) = name else {
|
||||
return Err(input.error("Expected name"));
|
||||
};
|
||||
Ok(Self::new(name.value(), alias.map(|i| i.value())))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
mod args;
|
||||
use args::RegisterExtensionInput;
|
||||
use args::{RegisterExtensionInput, ScalarInfo};
|
||||
use quote::{format_ident, quote};
|
||||
use syn::{parse_macro_input, DeriveInput};
|
||||
use syn::{parse_macro_input, DeriveInput, ItemFn};
|
||||
extern crate proc_macro;
|
||||
use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree};
|
||||
use std::collections::HashMap;
|
||||
|
@ -138,65 +138,61 @@ fn generate_get_description(
|
|||
enum_impl.parse().unwrap()
|
||||
}
|
||||
|
||||
#[proc_macro_derive(ScalarDerive)]
|
||||
pub fn derive_scalar(input: TokenStream) -> TokenStream {
|
||||
let ast = parse_macro_input!(input as DeriveInput);
|
||||
let struct_name = &ast.ident;
|
||||
|
||||
let register_fn_name = format_ident!("register_{}", struct_name);
|
||||
let exec_fn_name = format_ident!("{}_exec", struct_name);
|
||||
|
||||
let alias_check = quote! {
|
||||
if let Some(alias) = scalar.alias() {
|
||||
let alias_c_name = std::ffi::CString::new(alias).unwrap();
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn scalar(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let ast = parse_macro_input!(input as ItemFn);
|
||||
let fn_name = &ast.sig.ident;
|
||||
let scalar_info = parse_macro_input!(attr as ScalarInfo);
|
||||
let name = &scalar_info.name;
|
||||
let register_fn_name = format_ident!("register_{}", fn_name);
|
||||
let fn_body = &ast.block;
|
||||
let alias_check = if let Some(alias) = &scalar_info.alias {
|
||||
quote! {
|
||||
let Ok(alias_c_name) = std::ffi::CString::new(#alias) else {
|
||||
return ::limbo_ext::ResultCode::Error;
|
||||
};
|
||||
(api.register_scalar_function)(
|
||||
api.ctx,
|
||||
alias_c_name.as_ptr(),
|
||||
#exec_fn_name,
|
||||
#fn_name,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
quote! {}
|
||||
};
|
||||
|
||||
let expanded = quote! {
|
||||
impl #struct_name {
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn #register_fn_name(
|
||||
api: *const ::limbo_ext::ExtensionApi
|
||||
) -> ::limbo_ext::ResultCode {
|
||||
if api.is_null() {
|
||||
return ::limbo_ext::RESULT_ERROR;
|
||||
}
|
||||
let api = unsafe { &*api };
|
||||
|
||||
let scalar = #struct_name;
|
||||
let name = scalar.name();
|
||||
let c_name = std::ffi::CString::new(name).unwrap();
|
||||
|
||||
(api.register_scalar_function)(
|
||||
api.ctx,
|
||||
c_name.as_ptr(),
|
||||
#exec_fn_name,
|
||||
);
|
||||
|
||||
#alias_check
|
||||
|
||||
::limbo_ext::RESULT_OK
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn #register_fn_name(
|
||||
api: *const ::limbo_ext::ExtensionApi
|
||||
) -> ::limbo_ext::ResultCode {
|
||||
if api.is_null() {
|
||||
return ::limbo_ext::ResultCode::Error;
|
||||
}
|
||||
let api = unsafe { &*api };
|
||||
let Ok(c_name) = std::ffi::CString::new(#name) else {
|
||||
return ::limbo_ext::ResultCode::Error;
|
||||
};
|
||||
(api.register_scalar_function)(
|
||||
api.ctx,
|
||||
c_name.as_ptr(),
|
||||
#fn_name,
|
||||
);
|
||||
#alias_check
|
||||
::limbo_ext::ResultCode::OK
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn #exec_fn_name(
|
||||
pub unsafe extern "C" fn #fn_name(
|
||||
argc: i32,
|
||||
argv: *const ::limbo_ext::Value
|
||||
) -> ::limbo_ext::Value {
|
||||
let scalar = #struct_name;
|
||||
let args_slice = if argv.is_null() || argc <= 0 {
|
||||
let args = if argv.is_null() || argc <= 0 {
|
||||
&[]
|
||||
} else {
|
||||
unsafe { std::slice::from_raw_parts(argv, argc as usize) }
|
||||
};
|
||||
scalar.call(args_slice)
|
||||
#fn_body
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -254,21 +250,20 @@ pub fn derive_agg_func(input: TokenStream) -> TokenStream {
|
|||
api: *const ::limbo_ext::ExtensionApi
|
||||
) -> ::limbo_ext::ResultCode {
|
||||
if api.is_null() {
|
||||
return ::limbo_ext::RESULT_ERROR;
|
||||
return ::limbo_ext::ResultCode::Error;
|
||||
}
|
||||
|
||||
let api = &*api;
|
||||
let agg = #struct_name;
|
||||
let name_str = agg.name();
|
||||
let name_str = #struct_name::NAME;
|
||||
let c_name = match std::ffi::CString::new(name_str) {
|
||||
Ok(cname) => cname,
|
||||
Err(_) => return ::limbo_ext::RESULT_ERROR,
|
||||
Err(_) => return ::limbo_ext::ResultCode::Error,
|
||||
};
|
||||
|
||||
(api.register_aggregate_function)(
|
||||
api.ctx,
|
||||
c_name.as_ptr(),
|
||||
agg.args(),
|
||||
#struct_name::ARGS,
|
||||
#struct_name::#init_fn_name
|
||||
as ::limbo_ext::InitAggFunction,
|
||||
#struct_name::#step_fn_name
|
||||
|
@ -297,8 +292,8 @@ pub fn register_extension(input: TokenStream) -> TokenStream {
|
|||
syn::Ident::new(&format!("register_{}", scalar_ident), scalar_ident.span());
|
||||
quote! {
|
||||
{
|
||||
let result = unsafe { #scalar_ident::#register_fn(api)};
|
||||
if result != 0 {
|
||||
let result = unsafe { #register_fn(api)};
|
||||
if !result.is_ok() {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
@ -310,7 +305,7 @@ pub fn register_extension(input: TokenStream) -> TokenStream {
|
|||
quote! {
|
||||
{
|
||||
let result = unsafe{ #agg_ident::#register_fn(api)};
|
||||
if result != 0 {
|
||||
if !result.is_ok() {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
@ -319,13 +314,13 @@ pub fn register_extension(input: TokenStream) -> TokenStream {
|
|||
|
||||
let expanded = quote! {
|
||||
#[no_mangle]
|
||||
pub extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> i32 {
|
||||
pub extern "C" fn register_extension(api: &::limbo_ext::ExtensionApi) -> ::limbo_ext::ResultCode {
|
||||
let api = unsafe { &*api };
|
||||
#(#scalar_calls)*
|
||||
|
||||
#(#aggregate_calls)*
|
||||
|
||||
::limbo_ext::RESULT_OK
|
||||
::limbo_ext::ResultCode::OK
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue