fix complex Switch bug

This commit is contained in:
Folkert 2021-01-18 01:30:57 +01:00
parent 0b4af7e499
commit 4f4d555197
3 changed files with 99 additions and 24 deletions

View file

@ -1259,7 +1259,8 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
.build_extract_value( .build_extract_value(
argument, argument,
*index as u32, *index as u32,
env.arena.alloc(format!("struct_field_access_{}_", index)), env.arena
.alloc(format!("struct_field_access_single_element{}", index)),
) )
.unwrap() .unwrap()
} }
@ -1280,7 +1281,8 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
.build_extract_value( .build_extract_value(
argument, argument,
*index as u32, *index as u32,
env.arena.alloc(format!("struct_field_access_{}_", index)), env.arena
.alloc(format!("struct_field_access_record_{}", index)),
) )
.unwrap(), .unwrap(),
(StructValue(argument), Layout::Closure(_, _, _)) => env (StructValue(argument), Layout::Closure(_, _, _)) => env
@ -2119,8 +2121,13 @@ pub fn cast_basic_basic<'ctx>(
to_type: BasicTypeEnum<'ctx>, to_type: BasicTypeEnum<'ctx>,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
use inkwell::types::BasicType; use inkwell::types::BasicType;
// we can't use the more simple
// builder.build_bitcast(from_value, to_type, "cast_basic_basic")
// because this does not allow some (valid) bitcasts
// store the value in memory // store the value in memory
let argument_pointer = builder.build_alloca(from_value.get_type(), ""); let argument_pointer = builder.build_alloca(from_value.get_type(), "cast_alloca");
builder.build_store(argument_pointer, from_value); builder.build_store(argument_pointer, from_value);
// then read it back as a different type // then read it back as a different type
@ -2132,7 +2139,7 @@ pub fn cast_basic_basic<'ctx>(
) )
.into_pointer_value(); .into_pointer_value();
builder.build_load(to_type_pointer, "") builder.build_load(to_type_pointer, "cast_value")
} }
fn extract_tag_discriminant_struct<'a, 'ctx, 'env>( fn extract_tag_discriminant_struct<'a, 'ctx, 'env>(
@ -2155,15 +2162,12 @@ fn extract_tag_discriminant_ptr<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
from_value: PointerValue<'ctx>, from_value: PointerValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let ptr = cast_basic_basic( let tag_id_ptr_type = env.context.i64_type().ptr_type(AddressSpace::Generic);
env.builder,
from_value.into(), let ptr = env
env.context .builder
.i64_type() .build_bitcast(from_value, tag_id_ptr_type, "extract_tag_discriminant_ptr")
.ptr_type(AddressSpace::Generic) .into_pointer_value();
.into(),
)
.into_pointer_value();
env.builder.build_load(ptr, "load_tag_id").into_int_value() env.builder.build_load(ptr, "load_tag_id").into_int_value()
} }
@ -2252,9 +2256,17 @@ fn build_switch_ir<'a, 'ctx, 'env>(
Recursive(_) => { Recursive(_) => {
// we match on the discriminant, not the whole Tag // we match on the discriminant, not the whole Tag
cond_layout = Layout::Builtin(Builtin::Int64); cond_layout = Layout::Builtin(Builtin::Int64);
let full_cond_ptr = load_symbol(env, scope, cond_symbol).into_pointer_value();
extract_tag_discriminant_ptr(env, full_cond_ptr) use BasicValueEnum::*;
match load_symbol(env, scope, cond_symbol) {
PointerValue(full_cond_ptr) => {
extract_tag_discriminant_ptr(env, full_cond_ptr)
}
StructValue(full_cond_struct) => {
extract_tag_discriminant_struct(env, full_cond_struct)
}
_ => unreachable!(),
}
} }
NullableWrapped { nullable_id, .. } => { NullableWrapped { nullable_id, .. } => {
// we match on the discriminant, not the whole Tag // we match on the discriminant, not the whole Tag

View file

@ -1979,7 +1979,7 @@ mod gen_primitives {
} }
#[test] #[test]
fn nullable_eval() { fn nullable_eval_cfold() {
// the decision tree will generate a jump to the `1` branch here // the decision tree will generate a jump to the `1` branch here
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
@ -2013,4 +2013,65 @@ mod gen_primitives {
i64 i64
); );
} }
#[test]
fn non_nullable_eval_cfold() {
// the decision tree will generate a jump to the `1` branch here
assert_evals_to!(
indoc!(
r#"
app "test" provides [ main ] to "./platform"
Expr : [ Add Expr Expr, Mul Expr Expr, Val I64, Var I64 ]
mkExpr : I64, I64 -> Expr
mkExpr = \n , v ->
when n is
0 -> if v == 0 then Var 1 else Val v
_ -> Add (mkExpr (n-1) (v+1)) (mkExpr (n-1) (max (v-1) 0))
max : I64, I64 -> I64
max = \a, b -> if a > b then a else b
eval : Expr -> I64
eval = \e ->
when e is
Var _ -> 0
Val v -> v
Add l r -> eval l + eval r
Mul l r -> eval l * eval r
constFolding : Expr -> Expr
constFolding = \e ->
when e is
Add e1 e2 ->
x1 = constFolding e1
x2 = constFolding e2
when Pair x1 x2 is
Pair (Val a) (Val b) -> Val (a+b)
Pair (Val a) (Add (Val b) x) -> Add (Val (a+b)) x
Pair (Val a) (Add x (Val b)) -> Add (Val (a+b)) x
Pair _ _ -> Add x1 x2
Mul e1 e2 ->
x1 = constFolding e1
x2 = constFolding e2
when Pair x1 x2 is
Pair (Val a) (Val b) -> Val (a*b)
Pair (Val a) (Mul (Val b) x) -> Mul (Val (a*b)) x
Pair (Val a) (Mul x (Val b)) -> Mul (Val (a*b)) x
Pair _ _ -> Mul x1 x2
_ -> e
main : I64
main = eval (constFolding (mkExpr 3 1))
"#
),
11,
i64
);
}
} }

View file

@ -1477,15 +1477,15 @@ fn decide_to_branching<'a>(
// the cond_layout can change in the process. E.g. if the cond is a Tag, we actually // the cond_layout can change in the process. E.g. if the cond is a Tag, we actually
// switch on the tag discriminant (currently an i64 value) // switch on the tag discriminant (currently an i64 value)
// NOTE the tag discriminant is not actually loaded, `cond` can point to a tag // NOTE the tag discriminant is not actually loaded, `cond` can point to a tag
let (cond, cond_stores_vec, cond_layout) = let (inner_cond_symbol, cond_stores_vec, inner_cond_layout) =
path_to_expr_help(env, cond_symbol, &path, cond_layout); path_to_expr_help(env, cond_symbol, &path, cond_layout.clone());
let default_branch = decide_to_branching( let default_branch = decide_to_branching(
env, env,
procs, procs,
layout_cache, layout_cache,
cond_symbol, inner_cond_symbol,
cond_layout.clone(), inner_cond_layout.clone(),
ret_layout.clone(), ret_layout.clone(),
*fallback, *fallback,
jumps, jumps,
@ -1498,8 +1498,8 @@ fn decide_to_branching<'a>(
env, env,
procs, procs,
layout_cache, layout_cache,
cond_symbol, inner_cond_symbol,
cond_layout.clone(), inner_cond_layout.clone(),
ret_layout.clone(), ret_layout.clone(),
decider, decider,
jumps, jumps,
@ -1517,9 +1517,11 @@ fn decide_to_branching<'a>(
branches.push((tag, branch)); branches.push((tag, branch));
} }
// We have learned more about the exact layout of the cond (based on the path)
// but tests are still relative to the original cond symbol
let mut switch = Stmt::Switch { let mut switch = Stmt::Switch {
cond_layout, cond_layout: inner_cond_layout,
cond_symbol: cond, cond_symbol,
branches: branches.into_bump_slice(), branches: branches.into_bump_slice(),
default_branch: env.arena.alloc(default_branch), default_branch: env.arena.alloc(default_branch),
ret_layout, ret_layout,