Remove allocations from numeric text casting, cleanups

This commit is contained in:
PThorpe92 2025-02-24 12:30:38 -05:00
parent 7e94a152a5
commit 6d55cdba3b
No known key found for this signature in database
GPG key ID: 66DB3FBACBDD05CC
3 changed files with 88 additions and 519 deletions

View file

@ -1,13 +1,22 @@
use core::num::IntErrorKind;
use limbo_sqlite3_parser::ast::{self, CreateTableBody, Expr, FunctionTail, Literal};
use std::{rc::Rc, sync::Arc};
use crate::{
schema::{self, Column, Schema, Type},
types::OwnedValue,
LimboError, OpenFlags, Result, Statement, StepResult, IO,
};
pub trait RoundToPrecision {
fn round_to_precision(self, precision: f64) -> f64;
}
impl RoundToPrecision for f64 {
fn round_to_precision(self, precision: f64) -> f64 {
let factor = 10f64.powf(precision);
(self * factor).round() / factor
}
}
// https://sqlite.org/lang_keywords.html
const QUOTE_PAIRS: &[(char, char)] = &[('"', '"'), ('[', ']'), ('`', '`')];
@ -604,184 +613,6 @@ pub fn decode_percent(uri: &str) -> String {
String::from_utf8_lossy(&decoded).to_string()
}
#[derive(Debug, PartialEq)]
/// Reference:
/// https://github.com/sqlite/sqlite/blob/master/src/util.c#L798
pub enum CastTextToIntResultCode {
NotInt = -1,
Success = 0,
ExcessSpace = 1,
TooLargeOrMalformed = 2,
#[allow(dead_code)]
SpecialCase = 3,
}
pub fn text_to_integer(text: &str) -> (OwnedValue, CastTextToIntResultCode) {
let text = text.trim();
if text.is_empty() {
return (OwnedValue::Integer(0), CastTextToIntResultCode::NotInt);
}
let mut accum = String::new();
let mut sign = false;
let mut has_digit = false;
let mut excess_space = false;
let chars = text.chars();
for c in chars {
match c {
'0'..='9' => {
has_digit = true;
accum.push(c);
}
'+' | '-' if !has_digit && !sign => {
sign = true;
accum.push(c);
}
_ => {
excess_space = true;
break;
}
}
}
match accum.parse::<i64>() {
Ok(num) => {
if excess_space {
return (
OwnedValue::Integer(num),
CastTextToIntResultCode::ExcessSpace,
);
}
return (OwnedValue::Integer(num), CastTextToIntResultCode::Success);
}
Err(e) => match e.kind() {
IntErrorKind::NegOverflow | IntErrorKind::PosOverflow => (
OwnedValue::Integer(0),
CastTextToIntResultCode::TooLargeOrMalformed,
),
_ => (OwnedValue::Integer(0), CastTextToIntResultCode::NotInt),
},
}
}
#[derive(Debug, PartialEq)]
/// Reference
/// https://github.com/sqlite/sqlite/blob/master/src/util.c#L529
pub enum CastTextToRealResultCode {
PureInt = 1,
HasDecimal = 2,
NotValid = 0,
NotValidButPrefix = -1,
}
pub fn text_to_real(text: &str) -> (OwnedValue, CastTextToRealResultCode) {
let text = text.trim();
if text.is_empty() {
return (OwnedValue::Float(0.0), CastTextToRealResultCode::NotValid);
}
let mut accum = String::new();
let mut has_decimal_separator = false;
let mut sign = false;
let mut exp_sign = false;
let mut has_exponent = false;
let mut has_digit = false;
let mut has_decimal_digit = false;
let mut excess_space = false;
let mut chars = text.chars();
'outer: while let Some(c) = chars.next() {
match c {
'0'..='9' if !has_decimal_separator => {
has_digit = true;
accum.push(c);
}
'0'..='9' => {
// This pattern is used for both decimal and exponent digits
has_decimal_digit = true;
accum.push(c);
}
'+' | '-' if !has_digit && !sign => {
sign = true;
accum.push(c);
}
'.' if !has_decimal_separator => {
// Check if next char is a number
if let Some(ch) = chars.next() {
match ch {
'0'..='9' => {
has_decimal_separator = true;
accum.push(c);
accum.push(ch);
}
_ => {
excess_space = true;
break;
}
}
} else {
excess_space = true;
}
}
'E' | 'e' if !has_exponent && (!has_decimal_separator || has_decimal_digit) => {
// Lookahead if next char is a number or sign
let mut curr_sign = None;
loop {
if let Some(ch) = chars.next() {
match ch {
'0'..='9' => {
has_exponent = true;
accum.push(c);
if let Some(sign) = curr_sign {
exp_sign = true;
accum.push(sign);
}
accum.push(ch);
break;
}
'+' | '-' => {
curr_sign = Some(ch);
}
_ => {
excess_space = true;
break 'outer;
}
}
} else {
excess_space = true;
break 'outer;
}
}
}
_ => {
excess_space = true;
break;
}
}
}
if let Ok(num) = accum.parse::<f64>() {
if !has_decimal_separator && !exp_sign && !has_exponent && !excess_space {
return (OwnedValue::Float(num), CastTextToRealResultCode::PureInt);
}
if excess_space {
// TODO see if this branch satisfies: not a valid number, but has a valid prefix which
// includes a decimal point and/or an eNNN clause
return (
OwnedValue::Float(num),
CastTextToRealResultCode::NotValidButPrefix,
);
}
return (OwnedValue::Float(num), CastTextToRealResultCode::HasDecimal);
}
return (OwnedValue::Float(0.0), CastTextToRealResultCode::NotValid);
}
#[cfg(test)]
pub mod tests {
use super::*;
@ -1332,264 +1163,4 @@ pub mod tests {
"/home/user/db.sqlite"
);
}
#[test]
fn test_text_to_integer() {
assert_eq!(
text_to_integer("1"),
(OwnedValue::Integer(1), CastTextToIntResultCode::Success),
);
assert_eq!(
text_to_integer("-1"),
(OwnedValue::Integer(-1), CastTextToIntResultCode::Success),
);
assert_eq!(
text_to_integer("10000000"),
(
OwnedValue::Integer(10000000),
CastTextToIntResultCode::Success,
),
);
assert_eq!(
text_to_integer("-10000000"),
(
OwnedValue::Integer(-10000000),
CastTextToIntResultCode::Success,
),
);
assert_eq!(
text_to_integer("xxx"),
(OwnedValue::Integer(0), CastTextToIntResultCode::NotInt),
);
assert_eq!(
text_to_integer("123xxx"),
(
OwnedValue::Integer(123),
CastTextToIntResultCode::ExcessSpace,
),
);
assert_eq!(
text_to_integer("9223372036854775807"),
(
OwnedValue::Integer(i64::MAX),
CastTextToIntResultCode::Success,
),
);
assert_eq!(
text_to_integer("9223372036854775808"),
(
OwnedValue::Integer(0),
CastTextToIntResultCode::TooLargeOrMalformed,
),
);
assert_eq!(
text_to_integer("-9223372036854775808"),
(
OwnedValue::Integer(i64::MIN),
CastTextToIntResultCode::Success,
),
);
assert_eq!(
text_to_integer("-9223372036854775809"),
(
OwnedValue::Integer(0),
CastTextToIntResultCode::TooLargeOrMalformed,
),
);
assert_eq!(
text_to_integer("-"),
(OwnedValue::Integer(0), CastTextToIntResultCode::NotInt,),
);
}
#[test]
fn test_text_to_real() {
assert_eq!(
text_to_real("1"),
(OwnedValue::Float(1.0), CastTextToRealResultCode::PureInt),
);
assert_eq!(
text_to_real("-1"),
(OwnedValue::Float(-1.0), CastTextToRealResultCode::PureInt),
);
assert_eq!(
text_to_real("1.0"),
(OwnedValue::Float(1.0), CastTextToRealResultCode::HasDecimal),
);
assert_eq!(
text_to_real("-1.0"),
(
OwnedValue::Float(-1.0),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("1e10"),
(
OwnedValue::Float(1e10),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("-1e10"),
(
OwnedValue::Float(-1e10),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("1e-10"),
(
OwnedValue::Float(1e-10),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("-1e-10"),
(
OwnedValue::Float(-1e-10),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("1.123e10"),
(
OwnedValue::Float(1.123e10),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("-1.123e10"),
(
OwnedValue::Float(-1.123e10),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("1.123e-10"),
(
OwnedValue::Float(1.123e-10),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("-1.123e-10"),
(
OwnedValue::Float(-1.123e-10),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("1-282584294928"),
(
OwnedValue::Float(1.0),
CastTextToRealResultCode::NotValidButPrefix
),
);
assert_eq!(
text_to_real("xxx"),
(OwnedValue::Float(0.0), CastTextToRealResultCode::NotValid),
);
assert_eq!(
text_to_real("1.7976931348623157e308"),
(
OwnedValue::Float(f64::MAX),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("1.7976931348623157e309"),
(
OwnedValue::Float(f64::INFINITY),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("-1.7976931348623157e308"),
(
OwnedValue::Float(f64::MIN),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("-1.7976931348623157e309"),
(
OwnedValue::Float(f64::NEG_INFINITY),
CastTextToRealResultCode::HasDecimal,
),
);
assert_eq!(
text_to_real("1E"),
(
OwnedValue::Float(1.0),
CastTextToRealResultCode::NotValidButPrefix,
),
);
assert_eq!(
text_to_real("1EE"),
(
OwnedValue::Float(1.0),
CastTextToRealResultCode::NotValidButPrefix,
),
);
assert_eq!(
text_to_real("-1E"),
(
OwnedValue::Float(-1.0),
CastTextToRealResultCode::NotValidButPrefix,
),
);
assert_eq!(
text_to_real("1."),
(
OwnedValue::Float(1.0),
CastTextToRealResultCode::NotValidButPrefix,
),
);
assert_eq!(
text_to_real("-1."),
(
OwnedValue::Float(-1.0),
CastTextToRealResultCode::NotValidButPrefix,
),
);
assert_eq!(
text_to_real("1.23E"),
(
OwnedValue::Float(1.23),
CastTextToRealResultCode::NotValidButPrefix,
),
);
assert_eq!(
text_to_real("1.23E-"),
(
OwnedValue::Float(1.23),
CastTextToRealResultCode::NotValidButPrefix,
),
);
assert_eq!(
text_to_real("0"),
(OwnedValue::Float(0.0), CastTextToRealResultCode::PureInt,),
);
assert_eq!(
text_to_real("-0"),
(OwnedValue::Float(-0.0), CastTextToRealResultCode::PureInt,),
);
assert_eq!(
text_to_real("-0"),
(OwnedValue::Float(0.0), CastTextToRealResultCode::PureInt,),
);
assert_eq!(
text_to_real("-0.0"),
(OwnedValue::Float(0.0), CastTextToRealResultCode::HasDecimal,),
);
assert_eq!(
text_to_real("0.0"),
(OwnedValue::Float(0.0), CastTextToRealResultCode::HasDecimal,),
);
assert_eq!(
text_to_real("-"),
(OwnedValue::Float(0.0), CastTextToRealResultCode::NotValid,),
);
}
}

View file

@ -3,6 +3,7 @@ use std::num::NonZero;
use super::{cast_text_to_numeric, AggFunc, BranchOffset, CursorID, FuncCtx, PageIdx};
use crate::storage::wal::CheckpointMode;
use crate::types::{OwnedValue, Record};
use crate::util::RoundToPrecision;
use limbo_macros::Description;
macro_rules! final_agg_values {
@ -712,7 +713,7 @@ pub fn exec_add(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
}
}
(OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => {
OwnedValue::Float((lhs + rhs).round_to_precision(6))
OwnedValue::Float((lhs + rhs).round_to_precision(6.0))
}
(OwnedValue::Float(f), OwnedValue::Integer(i))
| (OwnedValue::Integer(i), OwnedValue::Float(f)) => OwnedValue::Float(*f + *i as f64),
@ -768,7 +769,7 @@ pub fn exec_multiply(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
}
}
(OwnedValue::Float(lhs), OwnedValue::Float(rhs)) => {
OwnedValue::Float((lhs * rhs).round_to_precision(6))
OwnedValue::Float((lhs * rhs).round_to_precision(6.0))
}
(OwnedValue::Integer(i), OwnedValue::Float(f))
| (OwnedValue::Float(f), OwnedValue::Integer(i)) => OwnedValue::Float(*i as f64 * { *f }),
@ -1083,17 +1084,6 @@ pub fn exec_or(mut lhs: &OwnedValue, mut rhs: &OwnedValue) -> OwnedValue {
}
}
trait RoundToPrecision {
fn round_to_precision(self, precision: i32) -> f64;
}
impl RoundToPrecision for f64 {
fn round_to_precision(self, precision: i32) -> f64 {
let factor = 10f64.powi(precision);
(self * factor).round() / factor
}
}
#[cfg(test)]
mod tests {
use crate::{

View file

@ -39,12 +39,10 @@ use crate::storage::wal::CheckpointResult;
use crate::storage::{btree::BTreeCursor, pager::Pager};
use crate::translate::plan::{ResultSetColumn, TableReference};
use crate::types::{
AggContext, Cursor, CursorResult, ExternalAggState, OwnedValue, Record, SeekKey, SeekOp,
};
use crate::util::{
parse_schema_rows, text_to_integer, text_to_real, CastTextToIntResultCode,
CastTextToRealResultCode,
AggContext, Cursor, CursorResult, ExternalAggState, OwnedValue, OwnedValueType, Record,
SeekKey, SeekOp,
};
use crate::util::{parse_schema_rows, RoundToPrecision};
use crate::vdbe::builder::CursorType;
use crate::vdbe::insn::Insn;
use crate::vector::{vector32, vector64, vector_distance_cos, vector_extract};
@ -1179,20 +1177,18 @@ impl Program {
} else {
conn.auto_commit.replace(*auto_commit);
}
} else if !*auto_commit {
return Err(LimboError::TxError(
"cannot start a transaction within a transaction".to_string(),
));
} else if *rollback {
return Err(LimboError::TxError(
"cannot rollback - no transaction is active".to_string(),
));
} else {
if !*auto_commit {
return Err(LimboError::TxError(
"cannot start a transaction within a transaction".to_string(),
));
} else if *rollback {
return Err(LimboError::TxError(
"cannot rollback - no transaction is active".to_string(),
));
} else {
return Err(LimboError::TxError(
"cannot commit - no transaction is active".to_string(),
));
}
return Err(LimboError::TxError(
"cannot commit - no transaction is active".to_string(),
));
}
return self.halt(pager);
}
@ -2040,7 +2036,7 @@ impl Program {
unreachable!("Cast with non-text type");
};
let result =
exec_cast(&reg_value_argument, &reg_value_type.as_str());
exec_cast(&reg_value_argument, reg_value_type.as_str());
state.registers[*dest] = result;
}
ScalarFunc::Changes => {
@ -2078,8 +2074,8 @@ impl Program {
};
OwnedValue::Integer(exec_glob(
cache,
&pattern.as_str(),
&text.as_str(),
pattern.as_str(),
text.as_str(),
)
as i64)
}
@ -2110,12 +2106,12 @@ impl Program {
let match_expression = &state.registers[*start_reg + 1];
let pattern = match pattern {
OwnedValue::Text(_) => pattern.clone(),
_ => exec_cast(pattern, "TEXT"),
OwnedValue::Text(_) => pattern,
_ => &exec_cast(pattern, "TEXT"),
};
let match_expression = match match_expression {
OwnedValue::Text(_) => match_expression.clone(),
_ => exec_cast(match_expression, "TEXT"),
OwnedValue::Text(_) => match_expression,
_ => &exec_cast(match_expression, "TEXT"),
};
let result = match (pattern, match_expression) {
@ -2131,8 +2127,8 @@ impl Program {
};
OwnedValue::Integer(exec_like_with_escape(
&pattern.as_str(),
&match_expression.as_str(),
pattern.as_str(),
match_expression.as_str(),
escape,
)
as i64)
@ -2148,14 +2144,14 @@ impl Program {
};
OwnedValue::Integer(exec_like(
cache,
&pattern.as_str(),
&match_expression.as_str(),
pattern.as_str(),
match_expression.as_str(),
)
as i64)
}
(OwnedValue::Null, OwnedValue::Null)
| (OwnedValue::Null, _)
| (_, OwnedValue::Null) => OwnedValue::Null,
(OwnedValue::Null, _) | (_, OwnedValue::Null) => {
OwnedValue::Null
}
_ => {
unreachable!("Like failed");
}
@ -2825,7 +2821,7 @@ impl Program {
.expect("only weak ref to connection?");
let auto_commit = *connection.auto_commit.borrow();
tracing::trace!("Halt auto_commit {}", auto_commit);
return if auto_commit {
if auto_commit {
let current_state = connection.transaction_state.borrow().clone();
if current_state == TransactionState::Read {
pager.end_read_tx()?;
@ -2849,8 +2845,8 @@ impl Program {
conn.set_changes(self.n_change.get());
}
}
return Ok(StepResult::Done);
};
Ok(StepResult::Done)
}
}
}
@ -3451,30 +3447,35 @@ fn exec_unicode(reg: &OwnedValue) -> OwnedValue {
fn _to_float(reg: &OwnedValue) -> f64 {
match reg {
OwnedValue::Text(x) => x.as_str().parse().unwrap_or(0.0),
OwnedValue::Text(x) => match cast_text_to_numeric(x.as_str()) {
OwnedValue::Integer(i) => i as f64,
OwnedValue::Float(f) => f,
_ => unreachable!(),
},
OwnedValue::Integer(x) => *x as f64,
OwnedValue::Float(x) => *x,
OwnedValue::Agg(ctx) => _to_float(ctx.final_value()),
_ => 0.0,
}
}
fn exec_round(reg: &OwnedValue, precision: Option<OwnedValue>) -> OwnedValue {
let precision = match precision {
Some(OwnedValue::Text(x)) => x.as_str().parse().unwrap_or(0.0),
Some(OwnedValue::Integer(x)) => x as f64,
Some(OwnedValue::Float(x)) => x,
Some(OwnedValue::Null) => return OwnedValue::Null,
_ => 0.0,
let reg = _to_float(reg);
let round = |reg: f64, f: f64| {
let precision = if f < 1.0 { 0.0 } else { f };
OwnedValue::Float(reg.round_to_precision(precision))
};
let reg = match reg {
OwnedValue::Agg(ctx) => _to_float(ctx.final_value()),
_ => _to_float(reg),
};
let precision = if precision < 1.0 { 0.0 } else { precision };
let multiplier = 10f64.powi(precision as i32);
OwnedValue::Float(((reg * multiplier).round()) / multiplier)
match precision {
Some(OwnedValue::Text(x)) => match cast_text_to_numeric(x.as_str()) {
OwnedValue::Integer(i) => round(reg, i as f64),
OwnedValue::Float(f) => round(reg, f),
_ => unreachable!(),
},
Some(OwnedValue::Integer(i)) => round(reg, i as f64),
Some(OwnedValue::Float(f)) => round(reg, f),
None => round(reg, 0.0),
_ => OwnedValue::Null,
}
}
// Implements TRIM pattern matching.
@ -3566,9 +3567,9 @@ fn exec_cast(value: &OwnedValue, datatype: &str) -> OwnedValue {
OwnedValue::Blob(b) => {
// Convert BLOB to TEXT first
let text = String::from_utf8_lossy(b);
cast_text_to_real(&text).0
cast_text_to_real(&text)
}
OwnedValue::Text(t) => cast_text_to_real(t.as_str()).0,
OwnedValue::Text(t) => cast_text_to_real(t.as_str()),
OwnedValue::Integer(i) => OwnedValue::Float(*i as f64),
OwnedValue::Float(f) => OwnedValue::Float(*f),
_ => OwnedValue::Float(0.0),
@ -3577,9 +3578,9 @@ fn exec_cast(value: &OwnedValue, datatype: &str) -> OwnedValue {
OwnedValue::Blob(b) => {
// Convert BLOB to TEXT first
let text = String::from_utf8_lossy(b);
cast_text_to_integer(&text).0
cast_text_to_integer(&text)
}
OwnedValue::Text(t) => cast_text_to_integer(t.as_str()).0,
OwnedValue::Text(t) => cast_text_to_integer(t.as_str()),
OwnedValue::Integer(i) => OwnedValue::Integer(*i),
// A cast of a REAL value into an INTEGER results in the integer between the REAL value and zero
// that is closest to the REAL value. If a REAL is greater than the greatest possible signed integer (+9223372036854775807)
@ -3677,14 +3678,14 @@ fn cast_text_to_integer(text: &str) -> OwnedValue {
/// the TEXT value are ignored when converging from TEXT to REAL.
/// If there is no prefix that can be interpreted as a real number, the result of the conversion is 0.0.
fn cast_text_to_real(text: &str) -> OwnedValue {
let trimmed = text.trim_start();
let trimmed = text.trim();
if trimmed.is_empty() {
return OwnedValue::Float(0.0);
}
if let Ok(num) = trimmed.parse::<f64>() {
return OwnedValue::Float(num);
}
let Ok((_, _, text)) = parse_numeric_str(trimmed) else {
let Ok((_, text)) = parse_numeric_str(trimmed) else {
return OwnedValue::Float(0.0);
};
text.parse::<f64>()
@ -3705,19 +3706,19 @@ pub fn checked_cast_text_to_numeric(text: &str) -> std::result::Result<OwnedValu
// sqlite will parse the first N digits of a string to numeric value, then determine
// whether _that_ value is more likely a real or integer value. e.g.
// '-100234-2344.23e14' evaluates to -100234 instead of -100234.0
let (has_decimal, has_exponent, text) = parse_numeric_str(text)?;
if !has_decimal && !has_exponent {
Ok(text
let (kind, text) = parse_numeric_str(text)?;
match kind {
OwnedValueType::Integer => Ok(text
.parse::<i64>()
.map_or(OwnedValue::Integer(0), OwnedValue::Integer))
} else {
Ok(text
.map_or(OwnedValue::Integer(0), OwnedValue::Integer)),
OwnedValueType::Float => Ok(text
.parse::<f64>()
.map_or(OwnedValue::Float(0.0), OwnedValue::Float))
.map_or(OwnedValue::Float(0.0), OwnedValue::Float)),
_ => unreachable!(),
}
}
fn parse_numeric_str(text: &str) -> Result<(bool, bool, &str), ()> {
fn parse_numeric_str(text: &str) -> Result<(OwnedValueType, &str), ()> {
let bytes = text.trim_start().as_bytes();
let mut end = 0;
let mut has_decimal = false;
@ -3746,7 +3747,14 @@ fn parse_numeric_str(text: &str) -> Result<(bool, bool, &str), ()> {
if end == 0 || (end == 1 && bytes[0] == b'-') {
return Err(());
}
Ok((has_decimal, has_exponent, &text[..end]))
Ok((
if !has_decimal && !has_exponent {
OwnedValueType::Integer
} else {
OwnedValueType::Float
},
&text[..end],
))
}
fn cast_text_to_numeric(txt: &str) -> OwnedValue {