Reduce memory allocations

Fixes #26
This commit is contained in:
Pekka Enberg 2024-01-28 09:00:38 +02:00
parent af258a4958
commit 505e28aaeb
5 changed files with 129 additions and 96 deletions

View file

@ -1,6 +1,6 @@
use crate::pager::Pager;
use crate::sqlite3_ondisk::{BTreeCell, TableInteriorCell, TableLeafCell};
use crate::types::Record;
use crate::types::OwnedRecord;
use anyhow::Result;
@ -42,7 +42,7 @@ pub struct Cursor {
root_page: usize,
page: RefCell<Option<Arc<MemPage>>>,
rowid: RefCell<Option<u64>>,
record: RefCell<Option<Record>>,
record: RefCell<Option<OwnedRecord>>,
}
impl Cursor {
@ -97,7 +97,7 @@ impl Cursor {
Ok(self.rowid.borrow())
}
pub fn record(&self) -> Result<Ref<Option<Record>>> {
pub fn record(&self) -> Result<Ref<Option<OwnedRecord>>> {
Ok(self.record.borrow())
}
@ -105,7 +105,7 @@ impl Cursor {
self.record.borrow().is_some()
}
fn get_next_record(&mut self) -> Result<CursorResult<(Option<u64>, Option<Record>)>> {
fn get_next_record(&mut self) -> Result<CursorResult<(Option<u64>, Option<OwnedRecord>)>> {
loop {
let mem_page = {
let mem_page = self.page.borrow();
@ -152,7 +152,7 @@ impl Cursor {
}
BTreeCell::TableLeafCell(TableLeafCell { _rowid, _payload }) => {
mem_page.advance();
let record = crate::sqlite3_ondisk::read_record(_payload)?;
let record= crate::sqlite3_ondisk::read_record(_payload)?;
return Ok(CursorResult::Ok((Some(*_rowid), Some(record))));
}
}

View file

@ -100,10 +100,7 @@ impl Connection {
match cmd {
Cmd::Stmt(stmt) => {
let program = Arc::new(translate::translate(&self.schema, stmt)?);
Ok(Statement {
program,
pager: self.pager.clone(),
})
Ok(Statement::new(program, self.pager.clone()))
}
Cmd::Explain(_stmt) => todo!(),
Cmd::ExplainQueryPlan(_stmt) => todo!(),
@ -121,8 +118,8 @@ impl Connection {
match cmd {
Cmd::Stmt(stmt) => {
let program = Arc::new(translate::translate(&self.schema, stmt)?);
let state = vdbe::ProgramState::new(program.max_registers);
Ok(Some(Rows::new(state, program, self.pager.clone())))
let stmt = Statement::new(program, self.pager.clone());
Ok(Some(Rows { stmt }))
}
Cmd::Explain(stmt) => {
let program = translate::translate(&self.schema, stmt)?;
@ -160,40 +157,21 @@ impl Connection {
pub struct Statement {
program: Arc<vdbe::Program>,
state: vdbe::ProgramState,
pager: Arc<Pager>,
}
impl Statement {
pub fn query(&self) -> Result<Rows> {
let state = vdbe::ProgramState::new(self.program.max_registers);
Ok(Rows::new(state, self.program.clone(), self.pager.clone()))
}
pub fn reset(&self) {}
}
pub enum RowResult {
Row(Row),
IO,
Done,
}
pub struct Rows {
state: vdbe::ProgramState,
program: Arc<vdbe::Program>,
pager: Arc<Pager>,
}
impl Rows {
pub fn new(state: vdbe::ProgramState, program: Arc<vdbe::Program>, pager: Arc<Pager>) -> Self {
pub fn new(program: Arc<vdbe::Program>, pager: Arc<Pager>) -> Self {
let state = vdbe::ProgramState::new(program.max_registers);
Self {
state,
program,
state,
pager,
}
}
pub fn next(&mut self) -> Result<RowResult> {
pub fn step<'a>(&'a mut self) -> Result<RowResult<'a>> {
loop {
let result = self.program.step(&mut self.state, self.pager.clone())?;
match result {
@ -209,15 +187,42 @@ impl Rows {
}
}
}
pub fn query(&mut self) -> Result<Rows> {
let stmt = Statement::new(self.program.clone(), self.pager.clone());
Ok(Rows::new(stmt))
}
pub fn reset(&self) {}
}
pub struct Row {
pub values: Vec<Value>,
pub enum RowResult<'a> {
Row(Row<'a>),
IO,
Done,
}
impl Row {
pub struct Row<'a> {
pub values: Vec<Value<'a>>,
}
impl<'a> Row<'a> {
pub fn get<T: crate::types::FromValue>(&self, idx: usize) -> Result<T> {
let value = &self.values[idx];
T::from_value(value)
}
}
pub struct Rows {
stmt: Statement,
}
impl Rows {
pub fn new(stmt: Statement) -> Self {
Self { stmt }
}
pub fn next<'a>(&'a mut self) -> Result<RowResult<'a>> {
self.stmt.step()
}
}

View file

@ -26,10 +26,9 @@
use crate::buffer_pool::BufferPool;
use crate::io::{Buffer, Completion};
use crate::pager::Page;
use crate::types::{Record, Value};
use crate::types::{OwnedRecord, OwnedValue};
use crate::PageSource;
use anyhow::{anyhow, Result};
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use log::trace;
@ -296,7 +295,7 @@ impl TryFrom<u64> for SerialType {
}
}
pub fn read_record(payload: &[u8]) -> Result<Record> {
pub fn read_record(payload: &[u8]) -> Result<OwnedRecord> {
let mut pos = 0;
let (header_size, nr) = read_varint(payload)?;
assert!((header_size as usize) >= nr);
@ -318,24 +317,24 @@ pub fn read_record(payload: &[u8]) -> Result<Record> {
pos += usize;
values.push(value);
}
Ok(Record { values })
Ok(OwnedRecord { values })
}
pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize)> {
pub fn read_value(buf: & [u8], serial_type: &SerialType) -> Result<(OwnedValue, usize)> {
match *serial_type {
SerialType::Null => Ok((Value::Null, 0)),
SerialType::Null => Ok((OwnedValue::Null, 0)),
SerialType::UInt8 => {
if buf.len() < 1 {
return Err(anyhow!("Invalid UInt8 value"));
}
Ok((Value::Integer(buf[0] as i64), 1))
Ok((OwnedValue::Integer(buf[0] as i64), 1))
}
SerialType::BEInt16 => {
if buf.len() < 2 {
return Err(anyhow!("Invalid BEInt16 value"));
}
Ok((
Value::Integer(i16::from_be_bytes([buf[0], buf[1]]) as i64),
OwnedValue::Integer(i16::from_be_bytes([buf[0], buf[1]]) as i64),
2,
))
}
@ -344,7 +343,7 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize)
return Err(anyhow!("Invalid BEInt24 value"));
}
Ok((
Value::Integer(i32::from_be_bytes([0, buf[0], buf[1], buf[2]]) as i64),
OwnedValue::Integer(i32::from_be_bytes([0, buf[0], buf[1], buf[2]]) as i64),
3,
))
}
@ -353,7 +352,7 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize)
return Err(anyhow!("Invalid BEInt32 value"));
}
Ok((
Value::Integer(i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as i64),
OwnedValue::Integer(i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as i64),
4,
))
}
@ -362,7 +361,7 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize)
return Err(anyhow!("Invalid BEInt48 value"));
}
Ok((
Value::Integer(i64::from_be_bytes([
OwnedValue::Integer(i64::from_be_bytes([
0, 0, buf[0], buf[1], buf[2], buf[3], buf[4], buf[5],
])),
6,
@ -373,7 +372,7 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize)
return Err(anyhow!("Invalid BEInt64 value"));
}
Ok((
Value::Integer(i64::from_be_bytes([
OwnedValue::Integer(i64::from_be_bytes([
buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7],
])),
8,
@ -384,19 +383,19 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize)
return Err(anyhow!("Invalid BEFloat64 value"));
}
Ok((
Value::Float(f64::from_be_bytes([
OwnedValue::Float(f64::from_be_bytes([
buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7],
])),
8,
))
}
SerialType::ConstInt0 => Ok((Value::Integer(0), 0)),
SerialType::ConstInt1 => Ok((Value::Integer(1), 0)),
SerialType::ConstInt0 => Ok((OwnedValue::Integer(0), 0)),
SerialType::ConstInt1 => Ok((OwnedValue::Integer(1), 0)),
SerialType::Blob(n) => {
if buf.len() < n {
return Err(anyhow!("Invalid Blob value"));
}
Ok((Value::Blob(buf[0..n].to_vec()), n))
Ok((OwnedValue::Blob(buf[0..n].to_vec()), n))
}
SerialType::String(n) => {
if buf.len() < n {
@ -404,7 +403,7 @@ pub fn read_value(buf: &[u8], serial_type: &SerialType) -> Result<(Value, usize)
}
let bytes = buf[0..n].to_vec();
let value = unsafe { String::from_utf8_unchecked(bytes) };
Ok((Value::Text(Rc::new(value)), n))
Ok((OwnedValue::Text(value), n))
}
}
}
@ -458,22 +457,22 @@ mod tests {
}
#[rstest]
#[case(&[], SerialType::Null, Value::Null)]
#[case(&[255], SerialType::UInt8, Value::Integer(255))]
#[case(&[0x12, 0x34], SerialType::BEInt16, Value::Integer(0x1234))]
#[case(&[0x12, 0x34, 0x56], SerialType::BEInt24, Value::Integer(0x123456))]
#[case(&[0x12, 0x34, 0x56, 0x78], SerialType::BEInt32, Value::Integer(0x12345678))]
#[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC], SerialType::BEInt48, Value::Integer(0x123456789ABC))]
#[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xFF], SerialType::BEInt64, Value::Integer(0x123456789ABCDEFF))]
#[case(&[64, 9, 33, 251, 84, 68, 45, 24], SerialType::BEFloat64, Value::Float(3.141592653589793))]
#[case(&[], SerialType::ConstInt0, Value::Integer(0))]
#[case(&[], SerialType::ConstInt1, Value::Integer(1))]
#[case(&[1, 2, 3], SerialType::Blob(3), Value::Blob(vec![1, 2, 3]))]
#[case(&[65, 66, 67], SerialType::String(3), Value::Text("ABC".to_string().into()))]
#[case(&[], SerialType::Null, OwnedValue::Null)]
#[case(&[255], SerialType::UInt8, OwnedValue::Integer(255))]
#[case(&[0x12, 0x34], SerialType::BEInt16, OwnedValue::Integer(0x1234))]
#[case(&[0x12, 0x34, 0x56], SerialType::BEInt24, OwnedValue::Integer(0x123456))]
#[case(&[0x12, 0x34, 0x56, 0x78], SerialType::BEInt32, OwnedValue::Integer(0x12345678))]
#[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC], SerialType::BEInt48, OwnedValue::Integer(0x123456789ABC))]
#[case(&[0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xFF], SerialType::BEInt64, OwnedValue::Integer(0x123456789ABCDEFF))]
#[case(&[64, 9, 33, 251, 84, 68, 45, 24], SerialType::BEFloat64, OwnedValue::Float(3.141592653589793))]
#[case(&[], SerialType::ConstInt0, OwnedValue::Integer(0))]
#[case(&[], SerialType::ConstInt1, OwnedValue::Integer(1))]
#[case(&[1, 2, 3], SerialType::Blob(3), OwnedValue::Blob(vec![1, 2, 3]))]
#[case(&[65, 66, 67], SerialType::String(3), OwnedValue::Text("ABC".to_string()))]
fn test_read_value(
#[case] buf: &[u8],
#[case] serial_type: SerialType,
#[case] expected: Value,
#[case] expected: OwnedValue,
) {
let result = read_value(buf, &serial_type).unwrap();
assert_eq!(result, (expected, buf.len()));

View file

@ -1,16 +1,33 @@
use std::rc::Rc;
use anyhow::Result;
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
pub enum Value<'a> {
Null,
Integer(i64),
Float(f64),
Text(Rc<String>),
Text(&'a String),
Blob(&'a Vec<u8>),
}
#[derive(Debug, Clone, PartialEq)]
pub enum OwnedValue {
Null,
Integer(i64),
Float(f64),
Text(String),
Blob(Vec<u8>),
}
pub fn to_value<'a>(value: &'a OwnedValue) -> Value<'a> {
match value {
OwnedValue::Null => Value::Null,
OwnedValue::Integer(i) => Value::Integer(*i),
OwnedValue::Float(f) => Value::Float(*f),
OwnedValue::Text(s) => Value::Text(s),
OwnedValue::Blob(b) => Value::Blob(b),
}
}
pub trait FromValue {
fn from_value(value: &Value) -> Result<Self>
where
@ -36,12 +53,22 @@ impl FromValue for String {
}
#[derive(Debug)]
pub struct Record {
pub values: Vec<Value>,
pub struct Record<'a> {
pub values: Vec<Value<'a>>,
}
impl Record {
pub fn new(values: Vec<Value>) -> Self {
impl<'a> Record<'a> {
pub fn new(values: Vec<Value<'a>>) -> Self {
Self { values }
}
}
pub struct OwnedRecord {
pub values: Vec<OwnedValue>,
}
impl OwnedRecord {
pub fn new(values: Vec<OwnedValue>) -> Self {
Self { values }
}
}

View file

@ -1,8 +1,9 @@
use crate::btree::{Cursor, CursorResult};
use crate::pager::Pager;
use crate::types::{Record, Value};
use crate::types::{OwnedValue, Record};
use anyhow::Result;
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::sync::Arc;
@ -143,24 +144,24 @@ impl ProgramBuilder {
}
}
pub enum StepResult {
pub enum StepResult<'a> {
Done,
IO,
Row(Record),
Row(Record<'a>),
}
/// The program state describes the environment in which the program executes.
pub struct ProgramState {
pub pc: usize,
cursors: BTreeMap<usize, Cursor>,
registers: Vec<Value>,
cursors: RefCell<BTreeMap<usize, Cursor>>,
registers: Vec<OwnedValue>,
}
impl ProgramState {
pub fn new(max_registers: usize) -> Self {
let cursors = BTreeMap::new();
let cursors = RefCell::new(BTreeMap::new());
let mut registers = Vec::with_capacity(max_registers);
registers.resize(max_registers, Value::Null);
registers.resize(max_registers, OwnedValue::Null);
Self {
pc: 0,
cursors,
@ -191,10 +192,11 @@ impl Program {
}
}
pub fn step(&self, state: &mut ProgramState, pager: Arc<Pager>) -> Result<StepResult> {
pub fn step<'a>(&self, state: &'a mut ProgramState, pager: Arc<Pager>) -> Result<StepResult<'a>> {
loop {
let insn = &self.insns[state.pc];
trace_insn(state.pc, insn);
let mut cursors = state.cursors.borrow_mut();
match insn {
Insn::Init { target_pc } => {
state.pc = *target_pc;
@ -204,14 +206,14 @@ impl Program {
root_page,
} => {
let cursor = Cursor::new(pager.clone(), *root_page);
state.cursors.insert(*cursor_id, cursor);
cursors.insert(*cursor_id, cursor);
state.pc += 1;
}
Insn::OpenReadAwait => {
state.pc += 1;
}
Insn::RewindAsync { cursor_id } => {
let cursor = state.cursors.get_mut(cursor_id).unwrap();
let cursor = cursors.get_mut(cursor_id).unwrap();
match cursor.rewind()? {
CursorResult::Ok(()) => {}
CursorResult::IO => {
@ -225,7 +227,7 @@ impl Program {
cursor_id,
pc_if_empty,
} => {
let cursor = state.cursors.get_mut(cursor_id).unwrap();
let cursor = cursors.get_mut(cursor_id).unwrap();
cursor.wait_for_completion()?;
if cursor.is_empty() {
state.pc = *pc_if_empty;
@ -238,7 +240,7 @@ impl Program {
column,
dest,
} => {
let cursor = state.cursors.get_mut(cursor_id).unwrap();
let cursor = cursors.get_mut(cursor_id).unwrap();
if let Some(ref record) = *cursor.record()? {
state.registers[*dest] = record.values[*column].clone();
} else {
@ -252,13 +254,13 @@ impl Program {
} => {
let mut values = Vec::with_capacity(*register_end - *register_start);
for i in *register_start..*register_end {
values.push(state.registers[i].clone());
values.push(crate::types::to_value(&state.registers[i]));
}
state.pc += 1;
return Ok(StepResult::Row(Record::new(values)));
}
Insn::NextAsync { cursor_id } => {
let cursor = state.cursors.get_mut(cursor_id).unwrap();
let cursor = cursors.get_mut(cursor_id).unwrap();
match cursor.next()? {
CursorResult::Ok(_) => {}
CursorResult::IO => {
@ -272,7 +274,7 @@ impl Program {
cursor_id,
pc_if_next,
} => {
let cursor = state.cursors.get_mut(cursor_id).unwrap();
let cursor = cursors.get_mut(cursor_id).unwrap();
cursor.wait_for_completion()?;
if cursor.has_record() {
state.pc = *pc_if_next;
@ -290,22 +292,22 @@ impl Program {
state.pc = *target_pc;
}
Insn::Integer { value, dest } => {
state.registers[*dest] = Value::Integer(*value);
state.registers[*dest] = OwnedValue::Integer(*value);
state.pc += 1;
}
Insn::RowId { cursor_id, dest } => {
let cursor = state.cursors.get_mut(cursor_id).unwrap();
let cursor = cursors.get_mut(cursor_id).unwrap();
if let Some(ref rowid) = *cursor.rowid()? {
state.registers[*dest] = Value::Integer(*rowid as i64);
state.registers[*dest] = OwnedValue::Integer(*rowid as i64);
} else {
todo!();
}
state.pc += 1;
}
Insn::DecrJumpZero { reg, target_pc } => match state.registers[*reg] {
Value::Integer(n) => {
OwnedValue::Integer(n) => {
if n > 0 {
state.registers[*reg] = Value::Integer(n - 1);
state.registers[*reg] = OwnedValue::Integer(n - 1);
state.pc += 1;
} else {
state.pc = *target_pc;