1. Fix merge conflicts

2. change tests for extensions to return error instead of null (Preston)
This commit is contained in:
Krishna Vishal 2025-01-19 03:00:18 +05:30
parent ca097b1972
commit 6173aeeb3b
5 changed files with 117 additions and 37 deletions

View file

@ -3,6 +3,8 @@ use std::fmt;
use std::fmt::{Debug, Display};
use std::rc::Rc;
use crate::LimboError;
pub struct ExternalFunc {
pub name: String,
pub func: ExtFunc,
@ -102,6 +104,7 @@ impl Display for JsonFunc {
pub enum AggFunc {
Avg,
Count,
Count0,
GroupConcat,
Max,
Min,
@ -129,9 +132,25 @@ impl PartialEq for AggFunc {
}
impl AggFunc {
pub fn num_args(&self) -> usize {
match self {
Self::Avg => 1,
Self::Count0 => 0,
Self::Count => 1,
Self::GroupConcat => 1,
Self::Max => 1,
Self::Min => 1,
Self::StringAgg => 2,
Self::Sum => 1,
Self::Total => 1,
Self::External(func) => func.agg_args().unwrap_or(0),
}
}
pub fn to_string(&self) -> &str {
match self {
Self::Avg => "avg",
Self::Count0 => "count",
Self::Count => "count",
Self::GroupConcat => "group_concat",
Self::Max => "max",
@ -390,19 +409,64 @@ pub struct FuncCtx {
}
impl Func {
pub fn resolve_function(name: &str, arg_count: usize) -> Result<Self, ()> {
pub fn resolve_function(name: &str, arg_count: usize) -> Result<Self, LimboError> {
match name {
"avg" => Ok(Self::Agg(AggFunc::Avg)),
"count" => Ok(Self::Agg(AggFunc::Count)),
"group_concat" => Ok(Self::Agg(AggFunc::GroupConcat)),
"max" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Max)),
"avg" => {
if arg_count != 1 {
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
}
Ok(Self::Agg(AggFunc::Avg))
}
"count" => {
// Handle both COUNT() and COUNT(expr) cases
if arg_count == 0 {
Ok(Self::Agg(AggFunc::Count0)) // COUNT() case
} else if arg_count == 1 {
Ok(Self::Agg(AggFunc::Count)) // COUNT(expr) case
} else {
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
}
}
"group_concat" => {
if arg_count != 1 && arg_count != 2 {
println!("{}", arg_count);
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
}
Ok(Self::Agg(AggFunc::GroupConcat))
}
"max" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Max)),
"min" if arg_count == 0 || arg_count == 1 => Ok(Self::Agg(AggFunc::Min)),
"max" => {
if arg_count < 1 {
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
}
Ok(Self::Agg(AggFunc::Max))
}
"min" if arg_count > 1 => Ok(Self::Scalar(ScalarFunc::Min)),
"min" => {
if arg_count < 1 {
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
}
Ok(Self::Agg(AggFunc::Min))
}
"nullif" if arg_count == 2 => Ok(Self::Scalar(ScalarFunc::Nullif)),
"string_agg" => Ok(Self::Agg(AggFunc::StringAgg)),
"sum" => Ok(Self::Agg(AggFunc::Sum)),
"total" => Ok(Self::Agg(AggFunc::Total)),
"string_agg" => {
if arg_count != 2 {
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
}
Ok(Self::Agg(AggFunc::StringAgg))
}
"sum" => {
if arg_count != 1 {
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
}
Ok(Self::Agg(AggFunc::Sum))
}
"total" => {
if arg_count != 1 {
crate::bail_parse_error!("wrong number of arguments to function {}()", name)
}
Ok(Self::Agg(AggFunc::Total))
}
"char" => Ok(Self::Scalar(ScalarFunc::Char)),
"coalesce" => Ok(Self::Scalar(ScalarFunc::Coalesce)),
"concat" => Ok(Self::Scalar(ScalarFunc::Concat)),
@ -486,7 +550,7 @@ impl Func {
"trunc" => Ok(Self::Math(MathFunc::Trunc)),
#[cfg(not(target_family = "wasm"))]
"load_extension" => Ok(Self::Scalar(ScalarFunc::LoadExtension)),
_ => Err(()),
_ => crate::bail_parse_error!("no such function: {}", name),
}
}
}
}

View file

@ -74,7 +74,7 @@ pub fn translate_aggregation_step(
});
target_register
}
AggFunc::Count => {
AggFunc::Count | AggFunc::Count0 => {
let expr_reg = if agg.args.is_empty() {
program.alloc_register()
} else {
@ -87,7 +87,11 @@ pub fn translate_aggregation_step(
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Count,
func: if matches!(agg.func, AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
},
});
target_register
}

View file

@ -463,14 +463,18 @@ pub fn translate_aggregation_step_groupby(
});
target_register
}
AggFunc::Count => {
AggFunc::Count | AggFunc::Count0 => {
let expr_reg = program.alloc_register();
emit_column(program, expr_reg);
program.emit_insn(Insn::AggStep {
acc_reg: target_register,
col: expr_reg,
delimiter: 0,
func: AggFunc::Count,
func: if matches!(agg.func, AggFunc::Count0) {
AggFunc::Count0
} else {
AggFunc::Count
},
});
target_register
}

View file

@ -116,21 +116,20 @@ pub fn prepare_select_plan(
args_count,
) {
Ok(Func::Agg(f)) => {
let agg_args: Result<Vec<Expr>, LimboError> = match args {
// if args is None and its COUNT
None if name.0.to_uppercase() == "COUNT" => {
let count_args = vec![ast::Expr::Literal(
ast::Literal::Numeric("1".to_string()),
)];
Ok(count_args)
}
// if args is None and the function is not COUNT
None => crate::bail_parse_error!(
"Aggregate function {} requires arguments",
name.0
),
Some(args) => Ok(args.clone()),
};
let agg_args: Result<Vec<Expr>, LimboError> =
match (args, &f) {
(None, crate::function::AggFunc::Count0) => {
// COUNT() case
Ok(vec![ast::Expr::Literal(
ast::Literal::Numeric("1".to_string()),
)])
}
(None, _) => crate::bail_parse_error!(
"Aggregate function {} requires arguments",
name.0
),
(Some(args), _) => Ok(args.clone()),
};
let agg = Aggregate {
func: f,
@ -163,7 +162,7 @@ pub fn prepare_select_plan(
contains_aggregates,
});
}
Err(_) => {
Err(e) => {
if let Some(f) = syms.resolve_function(&name.0, args_count)
{
if let ExtFunc::Scalar(_) = f.as_ref().func {
@ -199,6 +198,9 @@ pub fn prepare_select_plan(
contains_aggregates: true,
});
}
continue; // Continue with the normal flow instead of returning
} else {
return Err(e);
}
}
}
@ -333,4 +335,4 @@ pub fn prepare_select_plan(
}
_ => todo!(),
}
}
}

View file

@ -1204,7 +1204,7 @@ impl Program {
// Total() never throws an integer overflow.
OwnedValue::Agg(Box::new(AggContext::Sum(OwnedValue::Float(0.0))))
}
AggFunc::Count => {
AggFunc::Count | AggFunc::Count0 => {
OwnedValue::Agg(Box::new(AggContext::Count(OwnedValue::Integer(0))))
}
AggFunc::Max => {
@ -1289,7 +1289,13 @@ impl Program {
};
*acc += col;
}
AggFunc::Count => {
AggFunc::Count | AggFunc::Count0 => {
// println!("here");
if matches!(&state.registers[*acc_reg], OwnedValue::Null) {
state.registers[*acc_reg] = OwnedValue::Agg(Box::new(
AggContext::Count(OwnedValue::Integer(0)),
));
}
let OwnedValue::Agg(agg) = state.registers[*acc_reg].borrow_mut()
else {
unreachable!();
@ -1437,7 +1443,7 @@ impl Program {
*acc /= count.clone();
}
AggFunc::Sum | AggFunc::Total => {}
AggFunc::Count => {}
AggFunc::Count | AggFunc::Count0 => {}
AggFunc::Max => {}
AggFunc::Min => {}
AggFunc::GroupConcat | AggFunc::StringAgg => {}
@ -1451,7 +1457,7 @@ impl Program {
AggFunc::Total => {
state.registers[*register] = OwnedValue::Float(0.0);
}
AggFunc::Count => {
AggFunc::Count | AggFunc::Count0 => {
state.registers[*register] = OwnedValue::Integer(0);
}
_ => {}
@ -4209,4 +4215,4 @@ mod tests {
expected_str
);
}
}
}