Create extern functions to support vtab xConnect in core/ext

This commit is contained in:
PThorpe92 2025-05-23 19:59:47 -04:00
parent 02e7726249
commit d51614a4fd
No known key found for this signature in database
GPG key ID: 66DB3FBACBDD05CC
2 changed files with 210 additions and 0 deletions

179
core/ext/vtab_xconnect.rs Normal file
View file

@ -0,0 +1,179 @@
use crate::{types::Value, Connection, Statement, StepResult};
use limbo_ext::{Conn as ExtConn, ResultCode, Stmt, Value as ExtValue};
use std::{
boxed::Box,
ffi::{c_char, c_void, CStr, CString},
num::NonZeroUsize,
ptr,
rc::Weak,
};
pub unsafe extern "C" fn close(ctx: *mut c_void) {
if ctx.is_null() {
return;
}
let weak_box: Box<Weak<Connection>> = Box::from_raw(ctx as *mut Weak<Connection>);
if let Some(conn) = weak_box.upgrade() {
let _ = conn.close();
}
}
pub unsafe extern "C" fn prepare_stmt(ctx: *mut ExtConn, sql: *const c_char) -> *const Stmt {
let c_str = unsafe { CStr::from_ptr(sql as *mut c_char) };
let sql_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Err(_) => return ptr::null_mut(),
};
if ctx.is_null() {
return ptr::null_mut();
}
let Ok(extcon) = ExtConn::from_ptr(ctx) else {
return ptr::null_mut();
};
let weak_ptr = extcon._ctx as *const Weak<Connection>;
let weak = &*weak_ptr;
let Some(conn) = weak.upgrade() else {
return ptr::null_mut();
};
match conn.prepare(&sql_str) {
Ok(stmt) => {
let raw_stmt = Box::into_raw(Box::new(stmt)) as *mut c_void;
Box::into_raw(Box::new(Stmt::new(
extcon._ctx,
raw_stmt,
stmt_bind_args_fn,
stmt_step,
stmt_get_row,
stmt_get_column_names,
stmt_free_current_row,
stmt_close,
))) as *const Stmt
}
Err(_) => ptr::null_mut(),
}
}
pub unsafe extern "C" fn stmt_bind_args_fn(
ctx: *mut Stmt,
idx: i32,
arg: *const ExtValue,
) -> ResultCode {
let Ok(stmt) = Stmt::from_ptr(ctx) else {
return ResultCode::Error;
};
let stmt_ctx: &mut Statement = unsafe { &mut *(stmt._ctx as *mut Statement) };
let Ok(owned_val) = Value::from_ffi_ptr(arg) else {
tracing::error!("stmt_bind_args_fn: failed to convert arg to Value");
return ResultCode::Error;
};
let Some(idx) = NonZeroUsize::new(idx as usize) else {
tracing::error!("stmt_bind_args_fn: invalid index");
return ResultCode::Error;
};
stmt_ctx.bind_at(idx, owned_val);
ResultCode::OK
}
pub unsafe extern "C" fn stmt_step(stmt: *mut Stmt) -> ResultCode {
let Ok(stmt) = Stmt::from_ptr(stmt) else {
tracing::error!("stmt_step: failed to convert stmt to Stmt");
return ResultCode::Error;
};
if stmt._conn.is_null() || stmt._ctx.is_null() {
tracing::error!("stmt_step: null connection or context");
return ResultCode::Error;
}
let conn: &Connection = unsafe { &*(stmt._conn as *const Connection) };
let stmt_ctx: &mut Statement = unsafe { &mut *(stmt._ctx as *mut Statement) };
while let Ok(res) = stmt_ctx.step() {
match res {
StepResult::Row => return ResultCode::Row,
StepResult::Done => return ResultCode::EOF,
StepResult::IO => {
// always handle IO step result internally.
let _ = conn.pager.io.run_once();
continue;
}
StepResult::Interrupt => return ResultCode::Interrupt,
StepResult::Busy => return ResultCode::Busy,
}
}
ResultCode::Error
}
pub unsafe extern "C" fn stmt_get_row(ctx: *mut Stmt) {
let Ok(stmt) = Stmt::from_ptr(ctx) else {
return;
};
if !stmt.current_row.is_null() {
stmt.free_current_row();
}
let stmt_ctx: &mut Statement = unsafe { &mut *(stmt._ctx as *mut Statement) };
if let Some(row) = stmt_ctx.row() {
let values = row.get_values();
let mut owned_values = Vec::with_capacity(row.len());
for value in values {
owned_values.push(Value::to_ffi(value));
}
stmt.current_row = Box::into_raw(owned_values.into_boxed_slice()) as *mut ExtValue;
stmt.current_row_len = row.len() as i32;
} else {
stmt.current_row_len = 0;
}
}
pub unsafe extern "C" fn stmt_free_current_row(ctx: *mut Stmt) {
let Ok(stmt) = Stmt::from_ptr(ctx) else {
return;
};
if !stmt.current_row.is_null() {
let values: &mut [ExtValue] =
std::slice::from_raw_parts_mut(stmt.current_row, stmt.current_row_len as usize);
for value in values.iter_mut() {
let owned_value = std::mem::take(value);
owned_value.__free_internal_type();
}
let _ = Box::from_raw(stmt.current_row);
}
}
pub unsafe extern "C" fn stmt_get_column_names(
ctx: *mut Stmt,
count: *mut i32,
) -> *mut *mut c_char {
let Ok(stmt) = Stmt::from_ptr(ctx) else {
*count = 0;
return ptr::null_mut();
};
let stmt_ctx: &mut Statement = unsafe { &mut *(stmt._ctx as *mut Statement) };
let num_cols = stmt_ctx.num_columns();
if num_cols == 0 {
*count = 0;
return ptr::null_mut();
}
let mut c_names: Vec<*mut c_char> = Vec::with_capacity(num_cols);
for i in 0..num_cols {
let name = stmt_ctx.get_column_name(i);
let c_str = CString::new(name.as_bytes()).unwrap();
c_names.push(c_str.into_raw());
}
*count = c_names.len() as i32;
let names_array = c_names.into_boxed_slice();
Box::into_raw(names_array) as *mut *mut c_char
}
pub unsafe extern "C" fn stmt_close(ctx: *mut Stmt) {
let Ok(stmt) = Stmt::from_ptr(ctx) else {
return;
};
if !stmt.current_row.is_null() {
stmt.free_current_row();
}
// take ownership of internal statement
let wrapper = Box::from_raw(stmt as *mut Stmt);
if !wrapper._ctx.is_null() {
let mut _stmt: Box<Statement> = Box::from_raw(wrapper._ctx as *mut Statement);
_stmt.reset()
}
}

View file

@ -23,6 +23,9 @@ pub enum ResultCode {
EOF = 15,
ReadOnly = 16,
RowID = 17,
Row = 18,
Interrupt = 19,
Busy = 20,
}
impl ResultCode {
@ -60,6 +63,34 @@ impl Display for ResultCode {
ResultCode::EOF => write!(f, "EOF"),
ResultCode::ReadOnly => write!(f, "Read Only"),
ResultCode::RowID => write!(f, "RowID"),
ResultCode::Row => write!(f, "Row"),
ResultCode::Interrupt => write!(f, "Interrupt"),
ResultCode::Busy => write!(f, "Busy"),
}
}
}
#[repr(C)]
#[derive(PartialEq, Debug, Eq, Clone, Copy)]
/// StepResult is used to represent the state of a query as it is exposed
/// to the public API of a connection in a virtual table extension.
/// the IO variant is always handled internally and therefore is not included here.
pub enum StepResult {
Error,
Row,
Done,
Interrupt,
Busy,
}
impl From<ResultCode> for StepResult {
fn from(code: ResultCode) -> Self {
match code {
ResultCode::Error => StepResult::Error,
ResultCode::Row => StepResult::Row,
ResultCode::EOF => StepResult::Done,
ResultCode::Interrupt => StepResult::Interrupt,
ResultCode::Busy => StepResult::Busy,
_ => StepResult::Error,
}
}
}