refactor: more well-rounded implementation

`?0` parameters are now handled by the parser.
This commit is contained in:
Levy A. 2025-01-15 16:33:33 -03:00
parent 5de2694834
commit 9b8722f38e
7 changed files with 109 additions and 189 deletions

View file

@ -4,6 +4,7 @@ mod function;
mod io;
#[cfg(feature = "json")]
mod json;
mod parameters;
mod pseudo;
mod result;
mod schema;
@ -44,7 +45,7 @@ use util::parse_schema_rows;
pub use error::LimboError;
use translate::select::prepare_select_plan;
pub type Result<T> = std::result::Result<T, error::LimboError>;
pub type Result<T, E = error::LimboError> = std::result::Result<T, E>;
use crate::translate::optimizer::optimize_plan;
pub use io::OpenFlags;
@ -475,16 +476,8 @@ impl Statement {
Ok(Rows::new(stmt))
}
pub fn parameter_count(&mut self) -> usize {
self.program.parameter_count()
}
pub fn parameter_name(&self, index: NonZero<usize>) -> Option<String> {
self.program.parameter_name(index)
}
pub fn parameter_index(&self, name: impl AsRef<str>) -> Option<NonZero<usize>> {
self.program.parameter_index(name)
pub fn parameters(&self) -> &parameters::Parameters {
&self.program.parameters
}
pub fn bind_at(&mut self, index: NonZero<usize>, value: Value) {
@ -492,7 +485,8 @@ impl Statement {
}
pub fn reset(&mut self) {
self.state.reset();
let state = vdbe::ProgramState::new(self.program.max_registers);
self.state = state
}
}

View file

@ -1710,8 +1710,8 @@ pub fn translate_expr(
}
_ => todo!(),
},
ast::Expr::Variable(_) => {
let index = program.pop_index();
ast::Expr::Variable(name) => {
let index = program.parameters.push(name);
program.emit_insn(Insn::Variable {
index,
dest: target_register,

View file

@ -32,121 +32,12 @@ use crate::vdbe::{builder::ProgramBuilder, insn::Insn, Program};
use crate::{bail_parse_error, Connection, LimboError, Result, SymbolTable};
use insert::translate_insert;
use select::translate_select;
use sqlite3_parser::ast::fmt::TokenStream;
use sqlite3_parser::ast::{self, fmt::ToTokens, PragmaName};
use sqlite3_parser::dialect::TokenType;
use std::cell::RefCell;
use std::fmt::Display;
use std::num::NonZero;
use std::rc::{Rc, Weak};
use std::str::FromStr;
/// One placeholder occurrence found in a SQL statement.
///
/// Every variant carries the slot index that a bound value is looked up by.
#[derive(Clone, Debug)]
pub enum Parameter {
    /// A bare `?` placeholder.
    Anonymous(NonZero<usize>),
    /// A `?` followed by an explicit number.
    Indexed(NonZero<usize>),
    /// Any other placeholder: the full variable text plus its slot index.
    Named(String, NonZero<usize>),
}

impl Parameter {
    /// Returns the slot index stored in this placeholder, regardless of its kind.
    pub fn index(&self) -> NonZero<usize> {
        // All three variants hold the index, so a single irrefutable or-pattern suffices.
        let (Self::Anonymous(slot) | Self::Indexed(slot) | Self::Named(_, slot)) = self;
        *slot
    }
}

/// Two placeholders are equal when they share a slot index; the variant and
/// the name (if any) do not take part in the comparison.
impl PartialEq for Parameter {
    fn eq(&self, rhs: &Self) -> bool {
        rhs.index() == self.index()
    }
}
/// `?` or `$` Prepared statement arg placeholder(s)
///
/// Records placeholders in the order they are pushed and hands out slot
/// indices starting at 1.
#[derive(Debug)]
pub struct Parameters {
    /// Slot index that the next call to [`Parameters::next_index`] returns.
    index: NonZero<usize>,
    /// Every placeholder pushed so far, in push order.
    pub list: Vec<Parameter>,
}

impl Default for Parameters {
    /// Same as [`Parameters::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Parameters {
    /// Creates an empty collection whose first handed-out slot index is 1.
    pub fn new() -> Self {
        Self {
            index: NonZero::<usize>::MIN,
            list: Vec::new(),
        }
    }

    /// Appends `value` to the end of the recorded placeholder list.
    pub fn push(&mut self, value: Parameter) {
        self.list.push(value);
    }

    /// Returns the current slot index and advances the counter by one.
    ///
    /// # Panics
    ///
    /// Panics if advancing the counter would overflow `usize`.
    pub fn next_index(&mut self) -> NonZero<usize> {
        let index = self.index;
        self.index = index
            .checked_add(1)
            .expect("parameter index overflowed usize");
        index
    }

    /// Returns the placeholder at position `index` of the recorded list.
    ///
    /// `index` is the 0-based position in push order, not a slot index.
    /// Returns `None` when the position is out of range.
    pub fn get(&self, index: usize) -> Option<&Parameter> {
        self.list.get(index)
    }
}
// https://sqlite.org/lang_expr.html#parameters
/// Collects every `TK_VARIABLE` token that carries text into `self.list`,
/// assigning each placeholder a slot index:
///
/// * a bare `?` takes the next free slot;
/// * `?N` takes slot `N` and moves the next free slot past `N` if needed;
/// * any other text is treated as a named placeholder, re-using the slot of an
///   earlier occurrence with the same text or taking the next free slot.
///
/// Tokens of any other type, and variable tokens without text, are ignored.
impl TokenStream for Parameters {
    type Error = std::convert::Infallible;

    fn append(
        &mut self,
        ty: TokenType,
        value: Option<&str>,
    ) -> std::result::Result<(), Self::Error> {
        if ty != TokenType::TK_VARIABLE {
            return Ok(());
        }
        let Some(variable) = value else {
            return Ok(());
        };

        // `strip_prefix` instead of `split_at(1)`: it cannot panic on empty
        // text or on text whose first character is wider than one byte.
        match variable.strip_prefix('?') {
            Some("") => {
                let index = self.next_index();
                self.push(Parameter::Anonymous(index));
                log::trace!("anonymous parameter at {index}");
            }
            Some(number) => {
                // NOTE(review): this panics when the suffix is not a positive
                // integer (e.g. `?0`). `Error` is `Infallible`, so the failure
                // cannot be reported from here; the input has to be rejected
                // before it reaches this point.
                let index: NonZero<usize> = number.parse().unwrap();
                // `>=`, not `>`: when `?N` names exactly the next free slot the
                // counter still has to move past it, otherwise a following bare
                // `?` would be handed the same slot again.
                if index >= self.index {
                    self.index = index.checked_add(1).unwrap();
                }
                self.push(Parameter::Indexed(index));
                log::trace!("indexed parameter at {index}");
            }
            None => {
                // Clone the earlier occurrence (if any) so the shared borrow of
                // `self.list` ends before pushing.
                let seen = self
                    .list
                    .iter()
                    .find(|p| matches!(p, Parameter::Named(name, _) if name.as_str() == variable))
                    .cloned();
                match seen {
                    Some(previous) => {
                        log::trace!("named parameter at {} as {}", previous.index(), variable);
                        self.push(previous);
                    }
                    None => {
                        let index = self.next_index();
                        self.push(Parameter::Named(variable.to_owned(), index));
                        log::trace!("named parameter at {index} as {variable}");
                    }
                }
            }
        }
        Ok(())
    }
}
/// Translate SQL statement into bytecode program.
pub fn translate(
schema: &Schema,
@ -156,13 +47,7 @@ pub fn translate(
connection: Weak<Connection>,
syms: &SymbolTable,
) -> Result<Program> {
let mut parameters = Parameters::new();
stmt.to_tokens(&mut parameters).unwrap();
// dbg!(&parameters);
// dbg!(&parameters.list.clone().dedup());
let mut program = ProgramBuilder::new(parameters);
let mut program = ProgramBuilder::new();
match stmt {
ast::Stmt::AlterTable(_, _) => bail_parse_error!("ALTER TABLE not supported yet"),

View file

@ -1,18 +1,17 @@
use std::{
cell::RefCell,
collections::HashMap,
num::NonZero,
rc::{Rc, Weak},
};
use crate::{
parameters::Parameters,
schema::{BTreeTable, Index, PseudoTable},
storage::sqlite3_ondisk::DatabaseHeader,
Connection,
};
use super::{BranchOffset, CursorID, Insn, InsnReference, Program};
#[allow(dead_code)]
pub struct ProgramBuilder {
next_free_register: usize,
@ -30,10 +29,7 @@ pub struct ProgramBuilder {
seekrowid_emitted_bitmask: u64,
// map of instruction index to manual comment (used in EXPLAIN)
comments: HashMap<InsnReference, &'static str>,
named_parameters: HashMap<String, NonZero<usize>>,
next_free_parameter_index: NonZero<usize>,
parameters: crate::translate::Parameters,
parameter_index: usize,
pub parameters: Parameters,
}
#[derive(Debug, Clone)]
@ -51,12 +47,11 @@ impl CursorType {
}
impl ProgramBuilder {
pub fn new(parameters: crate::translate::Parameters) -> Self {
pub fn new() -> Self {
Self {
next_free_register: 1,
next_free_label: 0,
next_free_cursor_id: 0,
next_free_parameter_index: 1.try_into().unwrap(),
insns: Vec::new(),
next_insn_label: None,
cursor_ref: Vec::new(),
@ -64,9 +59,7 @@ impl ProgramBuilder {
label_to_resolved_offset: HashMap::new(),
seekrowid_emitted_bitmask: 0,
comments: HashMap::new(),
named_parameters: HashMap::new(),
parameters,
parameter_index: 0,
parameters: Parameters::new(),
}
}
@ -349,13 +342,7 @@ impl ProgramBuilder {
comments: self.comments,
connection,
auto_commit: true,
parameters: self.parameters.list,
parameters: self.parameters,
}
}
pub fn pop_index(&mut self) -> NonZero<usize> {
let parameter = self.parameters.get(self.parameter_index).unwrap();
self.parameter_index += 1;
return parameter.index();
}
}

View file

@ -277,7 +277,7 @@ pub struct Program {
pub cursor_ref: Vec<(Option<String>, CursorType)>,
pub database_header: Rc<RefCell<DatabaseHeader>>,
pub comments: HashMap<InsnReference, &'static str>,
pub parameters: Vec<crate::translate::Parameter>,
pub parameters: crate::parameters::Parameters,
pub connection: Weak<Connection>,
pub auto_commit: bool,
}
@ -301,38 +301,6 @@ impl Program {
}
}
pub fn parameter_count(&self) -> usize {
self.parameters.len()
}
pub fn parameter_name(&self, index: NonZero<usize>) -> Option<String> {
use crate::translate::Parameter;
self.parameters.iter().find_map(|p| match p {
Parameter::Anonymous(i) if *i == index => Some("?".to_string()),
Parameter::Indexed(i) if *i == index => Some(format!("?{i}")),
Parameter::Named(name, i) if *i == index => Some(name.to_owned()),
_ => None,
})
}
pub fn parameter_index(&self, name: impl AsRef<str>) -> Option<NonZero<usize>> {
use crate::translate::Parameter;
self.parameters
.iter()
.find_map(|p| {
let Parameter::Named(parameter_name, index) = p else {
return None;
};
if name.as_ref() == parameter_name {
return Some(index);
}
None
})
.copied()
}
pub fn step<'a>(
&self,
state: &'a mut ProgramState,

View file

@ -573,22 +573,99 @@ mod tests {
Ok(())
}
// Checks that `Statement::reset` rewinds a prepared statement: after a reset,
// stepping `select * from test` yields the first row (value 1) again.
#[test]
fn test_statement_reset() -> anyhow::Result<()> {
let _ = env_logger::try_init();
let tmp_db = TempDatabase::new("create table test (i integer);");
let conn = tmp_db.connect_limbo();
// Two rows, so a statement that was not rewound would have a second row to return.
conn.execute("insert into test values (1)")?;
conn.execute("insert into test values (2)")?;
let mut stmt = conn.prepare("select * from test")?;
// First pass: step until the first row arrives, check it, and stop.
loop {
match stmt.step()? {
StepResult::Row(row) => {
assert_eq!(row.values[0], Value::Integer(1));
break;
}
// Pending I/O: drive it once and step again.
StepResult::IO => tmp_db.io.run_once()?,
// NOTE(review): any other result leaves the loop without running the
// assertion, so a statement that yields no row passes silently.
_ => break,
}
}
stmt.reset();
// Second pass: the first row seen after the reset must again hold 1.
loop {
match stmt.step()? {
StepResult::Row(row) => {
assert_eq!(row.values[0], Value::Integer(1));
break;
}
StepResult::IO => tmp_db.io.run_once()?,
_ => break,
}
}
Ok(())
}
// Checks that a statement can be re-bound after `reset`: `select ?` is run
// with 1 bound to parameter 1, then reset and run again with 2 bound.
#[test]
fn test_statement_reset_bind() -> anyhow::Result<()> {
let _ = env_logger::try_init();
let tmp_db = TempDatabase::new("create table test (i integer);");
let conn = tmp_db.connect_limbo();
let mut stmt = conn.prepare("select ?")?;
stmt.bind_at(1.try_into().unwrap(), Value::Integer(1));
// First run: every row returned must carry the first bound value.
// Unlike `test_statement_reset`, the loop keeps stepping after a row and
// only stops on a result that is neither a row nor pending I/O.
loop {
match stmt.step()? {
StepResult::Row(row) => {
assert_eq!(row.values[0], Value::Integer(1));
}
// Pending I/O: drive it once and step again.
StepResult::IO => tmp_db.io.run_once()?,
_ => break,
}
}
stmt.reset();
// Bind a different value to the same parameter after the reset.
stmt.bind_at(1.try_into().unwrap(), Value::Integer(2));
// Second run: rows must now carry the newly bound value.
loop {
match stmt.step()? {
StepResult::Row(row) => {
assert_eq!(row.values[0], Value::Integer(2));
}
StepResult::IO => tmp_db.io.run_once()?,
_ => break,
}
}
Ok(())
}
#[test]
fn test_statement_bind() -> anyhow::Result<()> {
let _ = env_logger::try_init();
let tmp_db = TempDatabase::new("create table test (i integer);");
let conn = tmp_db.connect_limbo();
let mut stmt = conn.prepare("select ?, ?1, :named, ?4")?;
let mut stmt = conn.prepare("select ?, ?1, :named, ?3, ?4")?;
stmt.bind_at(1.try_into().unwrap(), Value::Text(&"hello".to_string()));
let i = stmt.parameter_index(":named").unwrap();
let i = stmt.parameters().index(":named").unwrap();
stmt.bind_at(i, Value::Integer(42));
stmt.bind_at(3.try_into().unwrap(), Value::Blob(&vec![0x1, 0x2, 0x3]));
stmt.bind_at(4.try_into().unwrap(), Value::Float(0.5));
assert_eq!(stmt.parameter_count(), 3);
assert_eq!(stmt.parameters().count(), 4);
loop {
match stmt.step()? {
@ -601,12 +678,16 @@ mod tests {
assert_eq!(s, "hello")
}
if let Value::Integer(s) = row.values[2] {
assert_eq!(s, 42)
if let Value::Integer(i) = row.values[2] {
assert_eq!(i, 42)
}
if let Value::Float(s) = row.values[3] {
assert_eq!(s, 0.5)
if let Value::Blob(v) = row.values[3] {
assert_eq!(v, &vec![0x1 as u8, 0x2, 0x3])
}
if let Value::Float(f) = row.values[4] {
assert_eq!(f, 0.5)
}
}
StepResult::IO => {

View file

@ -441,7 +441,12 @@ impl Splitter for Tokenizer {
// do not include the '?' in the token
Ok((Some((&data[1..=i], TK_VARIABLE)), i + 1))
}
None => Ok((Some((&data[1..], TK_VARIABLE)), data.len())),
None => {
if !data[1..].is_empty() && data[1..].iter().all(|ch| *ch == b'0') {
return Err(Error::BadVariableName(None, None));
}
Ok((Some((&data[1..], TK_VARIABLE)), data.len()))
}
}
}
b'$' | b'@' | b'#' | b':' => {