Have args macro in extension take a range

This commit is contained in:
PThorpe92 2025-01-12 12:19:05 -05:00
parent c565fba195
commit 852817c9ff
No known key found for this signature in database
GPG key ID: 66DB3FBACBDD05CC
2 changed files with 50 additions and 10 deletions

View file

@ -6,26 +6,61 @@ register_extension! {
scalars: {
"uuid4_str" => uuid4_str,
"uuid4" => uuid4_blob,
"uuid7_str" => uuid7_str,
"uuid7" => uuid7_blob,
"uuid_str" => uuid_str,
"uuid_blob" => uuid_blob,
},
}
declare_scalar_functions! {
#[args(min = 0, max = 0)]
#[args(0)]
fn uuid4_str(_args: &[Value]) -> Value {
let uuid = uuid::Uuid::new_v4().to_string();
Value::from_text(uuid)
}
#[args(min = 0, max = 0)]
#[args(0)]
fn uuid4_blob(_args: &[Value]) -> Value {
let uuid = uuid::Uuid::new_v4();
let bytes = uuid.as_bytes();
Value::from_blob(bytes.to_vec())
}
#[args(min = 1, max = 1)]
#[args(0..=1)]
fn uuid7_str(args: &[Value]) -> Value {
let timestamp = if args.is_empty() {
let ctx = uuid::ContextV7::new();
uuid::Timestamp::now(ctx)
} else if args[0].value_type == limbo_extension::ValueType::Integer {
let ctx = uuid::ContextV7::new();
let int = args[0].value as i64;
uuid::Timestamp::from_unix(ctx, int as u64, 0)
} else {
return Value::null();
};
let uuid = uuid::Uuid::new_v7(timestamp);
Value::from_text(uuid.to_string())
}
#[args(0..=1)]
fn uuid7_blob(args: &[Value]) -> Value {
let timestamp = if args.is_empty() {
let ctx = uuid::ContextV7::new();
uuid::Timestamp::now(ctx)
} else if args[0].value_type == limbo_extension::ValueType::Integer {
let ctx = uuid::ContextV7::new();
let int = args[0].value as i64;
uuid::Timestamp::from_unix(ctx, int as u64, 0)
} else {
return Value::null();
};
let uuid = uuid::Uuid::new_v7(timestamp);
let bytes = uuid.as_bytes();
Value::from_blob(bytes.to_vec())
}
#[args(1)]
fn uuid_str(args: &[Value]) -> Value {
if args[0].value_type != limbo_extension::ValueType::Blob {
log::debug!("uuid_str was passed a non-blob arg");
@ -43,7 +78,7 @@ declare_scalar_functions! {
}
}
#[args(min = 1, max = 1)]
#[args(1)]
fn uuid_blob(args: &[Value]) -> Value {
if args[0].value_type != limbo_extension::ValueType::Text {
log::debug!("uuid_blob was passed a non-text arg");

View file

@ -53,6 +53,7 @@ macro_rules! register_scalar_functions {
/// Provide a cleaner interface to define scalar functions to extension authors
/// . e.g.
/// ```
/// #[args(1)]
/// fn scalar_func(args: &[Value]) -> Value {
/// if args.len() != 1 {
/// return Value::null();
@ -65,7 +66,7 @@ macro_rules! register_scalar_functions {
macro_rules! declare_scalar_functions {
(
$(
#[args(min = $min_args:literal, max = $max_args:literal)]
#[args($($args_count:tt)+)]
fn $func_name:ident ($args:ident : &[Value]) -> Value $body:block
)*
) => {
@ -74,28 +75,32 @@ macro_rules! declare_scalar_functions {
argc: i32,
argv: *const *const std::os::raw::c_void
) -> $crate::Value {
if !($min_args..=$max_args).contains(&argc) {
let valid_args = {
match argc {
$($args_count)+ => true,
_ => false,
}
};
if !valid_args {
return $crate::Value::null();
}
if argc == 0 || argv.is_null() {
let $args: &[$crate::Value] = &[];
$body
} else {
unsafe {
let ptr_slice = std::slice::from_raw_parts(argv, argc as usize);
let ptr_slice = unsafe{ std::slice::from_raw_parts(argv, argc as usize)};
let mut values = Vec::with_capacity(argc as usize);
for &ptr in ptr_slice {
let val_ptr = ptr as *const $crate::Value;
if val_ptr.is_null() {
values.push($crate::Value::null());
} else {
values.push(std::ptr::read(val_ptr));
unsafe{values.push(std::ptr::read(val_ptr))};
}
}
let $args: &[$crate::Value] = &values[..];
$body
}
}
}
)*
};