refactor parsing/deserialization

This commit is contained in:
Nikita Sivukhin 2025-10-09 16:31:49 +04:00
parent a2f4376bd2
commit 8584ee18a3
3 changed files with 75 additions and 138 deletions

View file

@ -1,4 +1,5 @@
use crate::types::Value;
use crate::types::ValueType;
use crate::vdbe::Register;
use crate::LimboError;
use crate::Result;
@ -7,6 +8,26 @@ pub mod operations;
pub mod vector_types;
use vector_types::*;
/// Builds a [`Vector`] from a register holding either text or a blob.
///
/// * Text values are parsed by `operations::text::vector_from_text`, using
///   `vec_ty` as the element type and falling back to
///   `VectorType::Float32Dense` when no type is given.
/// * Blob values are handed to `Vector::from_blob`; `vec_ty` is not consulted
///   on this path.
///
/// # Errors
///
/// Returns `LimboError::ConversionError` when the value is neither text nor a
/// blob, or when the blob payload cannot be obtained. Errors produced by the
/// text parser or by `Vector::from_blob` are passed through unchanged.
pub fn parse_vector(value: &Register, vec_ty: Option<VectorType>) -> Result<Vector> {
    let inner = value.get_value();
    match inner.value_type() {
        ValueType::Text => {
            let element_type = vec_ty.unwrap_or(VectorType::Float32Dense);
            // The value type was just reported as text, so a missing text
            // payload would be an invariant violation.
            let text = inner.to_text().expect("value must be text");
            operations::text::vector_from_text(element_type, text)
        }
        ValueType::Blob => match inner.to_blob() {
            Some(bytes) => Vector::from_blob(bytes.to_vec()),
            None => Err(LimboError::ConversionError(
                "Invalid vector value".to_string(),
            )),
        },
        _ => Err(LimboError::ConversionError(
            "Invalid vector type".to_string(),
        )),
    }
}
pub fn vector32(args: &[Register]) -> Result<Value> {
if args.len() != 1 {
return Err(LimboError::ConversionError(
@ -47,8 +68,7 @@ pub fn vector_extract(args: &[Register]) -> Result<Value> {
return Ok(Value::build_text("[]"));
}
let vector_type = vector_type(blob)?;
let vector = vector_deserialize(vector_type, blob)?;
let vector = Vector::from_blob(blob.to_vec())?;
Ok(Value::build_text(operations::text::vector_to_text(&vector)))
}

View file

@ -87,10 +87,5 @@ pub fn vector_from_text(vector_type: VectorType, text: &str) -> Result<Vector> {
}
};
}
let dims = vector_type.size_to_dims(data.len());
Ok(Vector {
vector_type,
dims,
data,
})
Vector::from_data(vector_type, data)
}

View file

@ -1,6 +1,3 @@
use crate::types::ValueType;
use crate::vdbe::Register;
use crate::vector::operations;
use crate::{LimboError, Result};
#[derive(Debug, Clone, PartialEq, Copy)]
@ -9,15 +6,6 @@ pub enum VectorType {
Float64Dense,
}
impl VectorType {
    /// Converts a payload size in bytes into a number of elements.
    ///
    /// Uses 4 bytes per element for `Float32Dense` and 8 bytes per element
    /// for `Float64Dense`. Integer division is used, so trailing bytes that
    /// do not form a whole element are ignored.
    pub fn size_to_dims(&self, size: usize) -> usize {
        let bytes_per_element = match self {
            VectorType::Float32Dense => 4,
            VectorType::Float64Dense => 8,
        };
        size / bytes_per_element
    }
}
#[derive(Debug)]
pub struct Vector {
pub vector_type: VectorType,
@ -26,6 +14,55 @@ pub struct Vector {
}
impl Vector {
/// Splits a serialized vector blob into its element type and raw payload.
///
/// Blobs with an even length are treated as `Float32Dense` and returned
/// untouched. Blobs with an odd length carry a trailing type byte, which is
/// removed from the returned payload: `1` means `Float32Dense`, `2` means
/// `Float64Dense`.
///
/// # Errors
///
/// Returns `LimboError::ConversionError` when the trailing type byte is
/// neither `1` nor `2`.
pub fn vector_type(mut blob: Vec<u8>) -> Result<(VectorType, Vec<u8>)> {
    let has_type_byte = blob.len() % 2 != 0;
    if !has_type_byte {
        return Ok((VectorType::Float32Dense, blob));
    }
    // An odd length is at least 1, so there is always a byte to pop here.
    let tag = blob.pop().unwrap();
    let detected = match tag {
        1 => VectorType::Float32Dense,
        2 => VectorType::Float64Dense,
        _ => {
            return Err(LimboError::ConversionError(
                "Invalid vector type".to_string(),
            ))
        }
    };
    Ok((detected, blob))
}
/// Deserializes a vector from its blob representation.
///
/// The element type and payload are obtained from `Self::vector_type`, and
/// the payload is then validated and wrapped by `Self::from_data`. Errors
/// from either step are returned to the caller.
pub fn from_blob(blob: Vec<u8>) -> Result<Self> {
    Self::vector_type(blob)
        .and_then(|(vector_type, data)| Self::from_data(vector_type, data))
}
/// Wraps raw element bytes into a [`Vector`] of the given type.
///
/// `dims` is derived from the byte length: 4 bytes per element for
/// `Float32Dense`, 8 bytes per element for `Float64Dense`.
///
/// # Errors
///
/// Returns `LimboError::InvalidArgument` when `data.len()` is not a whole
/// multiple of the element size for `vector_type`.
pub fn from_data(vector_type: VectorType, data: Vec<u8>) -> Result<Self> {
    // Element width in bytes plus the label used in the error message.
    let (element_size, label) = match vector_type {
        VectorType::Float32Dense => (4, "f32"),
        VectorType::Float64Dense => (8, "f64"),
    };
    let byte_len = data.len();
    if byte_len % element_size != 0 {
        return Err(LimboError::InvalidArgument(format!(
            "{label} dense vector unexpected data length: {byte_len}"
        )));
    }
    Ok(Vector {
        vector_type,
        dims: byte_len / element_size,
        data,
    })
}
/// # Safety
///
/// This method is used to reinterpret the underlying `Vec<u8>` data
@ -83,88 +120,6 @@ impl Vector {
}
}
/// Builds a [`Vector`] from a register holding either text or a blob.
///
/// * Text values are parsed by `operations::text::vector_from_text`, using
///   `vec_ty` as the element type and falling back to
///   `VectorType::Float32Dense` when no type is given.
/// * Blob values have their type detected by `vector_type`; when `vec_ty` is
///   given it must equal the detected type. The blob is then decoded by
///   `vector_deserialize`.
///
/// # Errors
///
/// Returns `LimboError::ConversionError` when the value is neither text nor a
/// blob, when the blob payload cannot be obtained, or when `vec_ty` disagrees
/// with the type detected from the blob. Errors from the helpers are passed
/// through.
pub fn parse_vector(value: &Register, vec_ty: Option<VectorType>) -> Result<Vector> {
    let inner = value.get_value();
    match inner.value_type() {
        ValueType::Text => {
            let element_type = vec_ty.unwrap_or(VectorType::Float32Dense);
            let text = inner.to_text().expect("value must be text");
            operations::text::vector_from_text(element_type, text)
        }
        ValueType::Blob => {
            let blob = inner
                .to_blob()
                .ok_or_else(|| LimboError::ConversionError("Invalid vector value".to_string()))?;
            let detected = vector_type(blob)?;
            match vec_ty {
                // An explicit type request must agree with what the blob encodes.
                Some(expected) if expected != detected => Err(LimboError::ConversionError(
                    "Invalid vector type".to_string(),
                )),
                _ => vector_deserialize(detected, blob),
            }
        }
        _ => Err(LimboError::ConversionError(
            "Invalid vector type".to_string(),
        )),
    }
}
pub fn vector_deserialize(vector_type: VectorType, blob: &[u8]) -> Result<Vector> {
match vector_type {
VectorType::Float32Dense => vector_deserialize_f32(blob),
VectorType::Float64Dense => vector_deserialize_f64(blob),
}
}
pub fn vector_deserialize_f64(blob: &[u8]) -> Result<Vector> {
Ok(Vector {
vector_type: VectorType::Float64Dense,
dims: (blob.len() - 1) / 8,
data: blob[..blob.len() - 1].to_vec(),
})
}
/// Decodes a `Float32Dense` vector from `blob`.
///
/// An even-length blob is pure element data. An odd-length blob carries a
/// trailing type byte (the same rule `vector_type` uses to classify blobs);
/// that byte is stripped so it does not end up inside the element data.
/// `dims` is the payload length divided by 4.
pub fn vector_deserialize_f32(blob: &[u8]) -> Result<Vector> {
    // Odd length implies at least one byte, so the subtraction cannot underflow.
    let payload = if blob.len() % 2 == 1 {
        &blob[..blob.len() - 1]
    } else {
        blob
    };
    Ok(Vector {
        vector_type: VectorType::Float32Dense,
        dims: payload.len() / 4,
        data: payload.to_vec(),
    })
}
/// Detects the element type encoded in a serialized vector blob.
///
/// Even-length blobs are always `Float32Dense`. Odd-length blobs end with a
/// type byte (`1` for `Float32Dense`, `2` for `Float64Dense`), and the bytes
/// before it must be a whole number of elements of that type.
///
/// # Errors
///
/// Returns `LimboError::ConversionError` with "Invalid vector type" for an
/// unknown type byte, and with "Invalid vector value" when the payload length
/// is not a multiple of the element size.
pub fn vector_type(blob: &[u8]) -> Result<VectorType> {
    if blob.len() % 2 == 0 {
        return Ok(VectorType::Float32Dense);
    }
    // Odd length: the final byte is the type marker, everything before it is data.
    let payload_len = blob.len() - 1;
    let (detected, element_size) = match blob[payload_len] {
        1 => (VectorType::Float32Dense, 4),
        2 => (VectorType::Float64Dense, 8),
        _ => {
            return Err(LimboError::ConversionError(
                "Invalid vector type".to_string(),
            ))
        }
    };
    if payload_len % element_size != 0 {
        return Err(LimboError::ConversionError(
            "Invalid vector value".to_string(),
        ));
    }
    Ok(detected)
}
#[cfg(test)]
mod tests {
use crate::vector::operations;
@ -279,9 +234,9 @@ mod tests {
fn test_vector_type<const DIMS: usize>(v: Vector) -> bool {
let vtype = v.vector_type;
let value = operations::serialize::vector_serialize(v);
let blob = value.to_blob().unwrap();
match vector_type(blob) {
Ok(detected_type) => detected_type == vtype,
let blob = value.to_blob().unwrap().to_vec();
match Vector::vector_type(blob) {
Ok((detected_type, _)) => detected_type == vtype,
Err(_) => false,
}
}
@ -329,39 +284,6 @@ mod tests {
}
}
// Test size_to_dims calculation with different dimensions
#[quickcheck]
fn prop_size_to_dims_calculation_2d(v: ArbitraryVector<2>) -> bool {
test_size_to_dims::<2>(v.into())
}
#[quickcheck]
fn prop_size_to_dims_calculation_3d(v: ArbitraryVector<3>) -> bool {
test_size_to_dims::<3>(v.into())
}
#[quickcheck]
fn prop_size_to_dims_calculation_4d(v: ArbitraryVector<4>) -> bool {
test_size_to_dims::<4>(v.into())
}
#[quickcheck]
fn prop_size_to_dims_calculation_100d(v: ArbitraryVector<100>) -> bool {
test_size_to_dims::<100>(v.into())
}
#[quickcheck]
fn prop_size_to_dims_calculation_1536d(v: ArbitraryVector<1536>) -> bool {
test_size_to_dims::<1536>(v.into())
}
/// Checks that `size_to_dims`, applied to the byte length of `v.data` for the
/// vector's own type, yields exactly `DIMS`.
fn test_size_to_dims<const DIMS: usize>(v: Vector) -> bool {
    let byte_len = v.data.len();
    v.vector_type.size_to_dims(byte_len) == DIMS
}
#[quickcheck]
fn prop_vector_distance_safety_2d(v1: ArbitraryVector<2>, v2: ArbitraryVector<2>) -> bool {
test_vector_distance::<2>(&v1.into(), &v2.into())