allow lowlevel and match

This commit is contained in:
J.Teeuwissen 2023-05-24 14:29:17 +02:00
parent 9b58c0fb9c
commit d7304f86e5
No known key found for this signature in database
GPG key ID: DB5F7A1ED8D478AD
4 changed files with 326 additions and 9 deletions

View file

@ -108,9 +108,18 @@ fn specialize_drops_stmt<'a, 'i>(
alloc_let_with_continuation!(environment)
}
_ => {
// TODO perhaps allow for some e.g. lowlevel functions to be called if they cannot modify the RC of the symbol.
// Check whether the increments can be passed to the continuation.
CallType::LowLevel { op, .. } => match low_level_no_rc(&op) {
// It should be safe to pass the increments to the continuation.
RC::NoRc => alloc_let_with_continuation!(environment),
// We probably should not pass the increments to the continuation.
RC::Rc | RC::Uknown => {
let mut new_environment = environment.clone_without_incremented();
alloc_let_with_continuation!(&mut new_environment)
}
},
_ => {
// Calls can modify the RC of the symbol.
// If we move a increment of children after the function,
// the function might deallocate the child before we can use it after the function.
@ -222,7 +231,7 @@ fn specialize_drops_stmt<'a, 'i>(
let new_branches = branches
.iter()
.map(|(label, info, branch)| {
let mut branch_env = environment.clone_without_incremented();
let mut branch_env = environment.clone();
insert_branch_info!(branch_env, info);
@ -234,7 +243,7 @@ fn specialize_drops_stmt<'a, 'i>(
branch,
);
(*label, info.clone(), new_branch.clone())
(*label, info.clone(), new_branch.clone(), branch_env)
})
.collect_in::<Vec<_>>(arena)
.into_bump_slice();
@ -242,7 +251,7 @@ fn specialize_drops_stmt<'a, 'i>(
let new_default_branch = {
let (info, branch) = default_branch;
let mut branch_env = environment.clone_without_incremented();
let mut branch_env = environment.clone();
insert_branch_info!(branch_env, info);
@ -254,14 +263,91 @@ fn specialize_drops_stmt<'a, 'i>(
branch,
);
(info.clone(), new_branch, branch_env)
};
// Find consumed increments in each branch and make sure they are consumed in all branches.
// By incrementing them in each branch where they were not consumed.
{
let branch_envs = {
let mut branch_environments =
Vec::with_capacity_in(new_branches.len() + 1, arena);
for (_, _, _, branch_env) in new_branches.iter() {
branch_environments.push(branch_env);
}
branch_environments.push(&new_default_branch.2);
branch_environments
};
// Find the lowest symbol count for each symbol in each branch, and update the environment to match.
for (symbol, count) in environment.incremented_symbols.iter_mut() {
let consumed = branch_envs
.iter()
.map(|branch_env| branch_env.incremented_symbols.get(symbol).unwrap_or(&0))
.min()
.unwrap();
// TODO verify this works.
*count = *consumed;
}
}
macro_rules! insert_incs {
($branch_env:expr, $branch:expr ) => {{
let symbol_differences =
environment
.incremented_symbols
.iter()
.filter_map(|(symbol, count)| {
let branch_count =
$branch_env.incremented_symbols.get(symbol).unwrap_or(&0);
match count - branch_count {
0 => None,
difference => Some(difference),
}
});
symbol_differences.fold($branch, |new_branch, difference| {
arena.alloc(Stmt::Refcounting(
ModifyRc::Inc(*cond_symbol, difference),
new_branch,
))
})
}};
}
let newer_branches = new_branches
.into_iter()
.map(|(label, info, branch, branch_env)| {
let new_branch = insert_incs!(branch_env, branch);
(*label, info.clone(), new_branch.clone())
})
.collect_in::<Vec<_>>(arena)
.into_bump_slice();
let newer_default_branch = {
let (info, branch, branch_env) = new_default_branch;
let new_branch = insert_incs!(branch_env, branch);
(info.clone(), new_branch)
};
// Remove all 0 counts as cleanup.
environment
.incremented_symbols
.retain(|_, count| *count > 0);
arena.alloc(Stmt::Switch {
cond_symbol: *cond_symbol,
cond_layout: *cond_layout,
branches: new_branches,
default_branch: new_default_branch,
branches: newer_branches,
default_branch: newer_default_branch,
ret_layout: *ret_layout,
})
}
@ -1291,3 +1377,111 @@ impl<'a> DropSpecializationEnvironment<'a> {
// TODO assert that a parent is only inlined once / assert max single dec per parent.
}
/**
Reference count information
*/
enum RC {
// Rc is important, moving an increment to after this function might break the program.
// E.g. if the function checks for uniqueness and behaves differently based on that.
Rc,
// Rc is not important, moving an increment to after this function should have no effect.
NoRc,
// Rc effect is unknown.
Uknown,
}
/*
Returns whether the reference count of arguments to this function is relevant to the program.
*/
fn low_level_no_rc(lowlevel: &LowLevel) -> RC {
use LowLevel::*;
match lowlevel {
Unreachable => RC::Uknown,
ListLen | StrIsEmpty | StrToScalars | StrCountGraphemes | StrGraphemes
| StrCountUtf8Bytes | StrGetCapacity | ListGetCapacity => RC::NoRc,
ListWithCapacity | StrWithCapacity => RC::NoRc,
ListReplaceUnsafe => RC::Rc,
StrGetUnsafe | ListGetUnsafe => RC::NoRc,
ListConcat => RC::Rc,
StrConcat => RC::Rc,
StrSubstringUnsafe => RC::NoRc,
StrReserve => RC::Rc,
StrAppendScalar => RC::Rc,
StrGetScalarUnsafe => RC::NoRc,
StrTrim => RC::Rc,
StrTrimLeft => RC::Rc,
StrTrimRight => RC::Rc,
StrSplit => RC::NoRc,
StrToNum => RC::NoRc,
ListPrepend => RC::Rc,
StrJoinWith => RC::NoRc,
ListMap | ListMap2 | ListMap3 | ListMap4 | ListSortWith => RC::Rc,
ListAppendUnsafe
| ListReserve
| ListSublist
| ListDropAt
| ListSwap
| ListReleaseExcessCapacity
| StrReleaseExcessCapacity => RC::Rc,
Eq | NotEq => RC::NoRc,
And | Or | NumAdd | NumAddWrap | NumAddChecked | NumAddSaturated | NumSub | NumSubWrap
| NumSubChecked | NumSubSaturated | NumMul | NumMulWrap | NumMulSaturated
| NumMulChecked | NumGt | NumGte | NumLt | NumLte | NumCompare | NumDivFrac
| NumDivTruncUnchecked | NumDivCeilUnchecked | NumRemUnchecked | NumIsMultipleOf
| NumPow | NumPowInt | NumBitwiseAnd | NumBitwiseXor | NumBitwiseOr | NumShiftLeftBy
| NumShiftRightBy | NumShiftRightZfBy => RC::NoRc,
NumToStr
| NumAbs
| NumNeg
| NumSin
| NumCos
| NumSqrtUnchecked
| NumLogUnchecked
| NumRound
| NumCeiling
| NumFloor
| NumToFrac
| Not
| NumIsNan
| NumIsInfinite
| NumIsFinite
| NumAtan
| NumAcos
| NumAsin
| NumIntCast
| NumToIntChecked
| NumToFloatCast
| NumToFloatChecked
| NumCountLeadingZeroBits
| NumCountTrailingZeroBits
| NumCountOneBits => RC::NoRc,
NumBytesToU16 => RC::NoRc,
NumBytesToU32 => RC::NoRc,
NumBytesToU64 => RC::NoRc,
NumBytesToU128 => RC::NoRc,
StrStartsWith | StrEndsWith => RC::NoRc,
StrStartsWithScalar => RC::NoRc,
StrFromUtf8Range => RC::Rc,
StrToUtf8 => RC::Rc,
StrRepeat => RC::NoRc,
StrFromInt | StrFromFloat => RC::NoRc,
Hash => RC::NoRc,
ListIsUnique => RC::Rc,
BoxExpr | UnboxExpr => {
unreachable!("These lowlevel operations are turned into mono Expr's")
}
PtrCast | PtrWrite | RefCountIncRcPtr | RefCountDecRcPtr | RefCountIncDataPtr
| RefCountDecDataPtr | RefCountIsUnique => {
unreachable!("Only inserted *after* borrow checking: {:?}", lowlevel);
}
}
}

View file

@ -2736,9 +2736,16 @@ impl<'a> LayoutRepr<'a> {
RecursivePointer(_) => true,
Builtin(List(_)) | Builtin(Str) => true,
Builtin(builtin ) => match builtin {
Int(_) | Float(_) | Bool | Decimal => false,
Str | List(_) => true,
}
Boxed(_) => true,
_ => false,
Struct(_) => false,
LambdaSet(_) => false,
}
}

View file

@ -0,0 +1,79 @@
procedure Num.19 (#Attr.2, #Attr.3):
let Num.282 : U64 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Num.282;
procedure Num.24 (#Attr.2, #Attr.3):
let Num.283 : Int1 = lowlevel NumGt #Attr.2 #Attr.3;
ret Num.283;
procedure Test.2 (Test.9, Test.10):
let Test.38 : U8 = 1i64;
let Test.39 : U8 = GetTagId Test.9;
let Test.40 : Int1 = lowlevel Eq Test.38 Test.39;
if Test.40 then
let Test.20 : U64 = CallByName Test.3 Test.10;
ret Test.20;
else
let Test.11 : Str = UnionAtIndex (Id 0) (Index 0) Test.9;
let Test.12 : [<rnu><null>, C Str *self] = UnionAtIndex (Id 0) (Index 1) Test.9;
let Test.35 : U8 = 1i64;
let Test.36 : U8 = GetTagId Test.10;
let Test.37 : Int1 = lowlevel Eq Test.35 Test.36;
if Test.37 then
let Test.29 : U64 = CallByName Test.3 Test.9;
ret Test.29;
else
joinpoint #Derived_gen.3:
let Test.13 : Str = UnionAtIndex (Id 0) (Index 0) Test.10;
let Test.14 : [<rnu><null>, C Str *self] = UnionAtIndex (Id 0) (Index 1) Test.10;
let Test.33 : U64 = CallByName Test.3 Test.12;
let Test.34 : U64 = 1i64;
let Test.15 : U64 = CallByName Num.19 Test.33 Test.34;
let Test.16 : U64 = CallByName Test.3 Test.10;
let Test.31 : Int1 = CallByName Num.24 Test.15 Test.16;
if Test.31 then
ret Test.15;
else
ret Test.16;
in
let #Derived_gen.4 : Int1 = lowlevel RefCountIsUnique Test.9;
if #Derived_gen.4 then
dec Test.11;
decref Test.9;
jump #Derived_gen.3;
else
inc Test.12;
decref Test.9;
jump #Derived_gen.3;
procedure Test.3 (Test.17):
let Test.26 : U8 = 1i64;
let Test.27 : U8 = GetTagId Test.17;
let Test.28 : Int1 = lowlevel Eq Test.26 Test.27;
if Test.28 then
let Test.22 : U64 = 0i64;
ret Test.22;
else
let Test.18 : [<rnu><null>, C Str *self] = UnionAtIndex (Id 0) (Index 1) Test.17;
joinpoint #Derived_gen.0:
let Test.24 : U64 = 1i64;
let Test.25 : U64 = CallByName Test.3 Test.18;
let Test.23 : U64 = CallByName Num.19 Test.24 Test.25;
ret Test.23;
in
let #Derived_gen.2 : Int1 = lowlevel RefCountIsUnique Test.17;
if #Derived_gen.2 then
let #Derived_gen.1 : Str = UnionAtIndex (Id 0) (Index 0) Test.17;
dec #Derived_gen.1;
decref Test.17;
jump #Derived_gen.0;
else
inc Test.18;
decref Test.17;
jump #Derived_gen.0;
procedure Test.0 ():
let Test.5 : [<rnu><null>, C Str *self] = TagId(1) ;
let Test.6 : [<rnu><null>, C Str *self] = TagId(1) ;
let Test.19 : U64 = CallByName Test.2 Test.5 Test.6;
ret Test.19;

View file

@ -3013,3 +3013,40 @@ fn rb_tree_fbip() {
"#
)
}
#[mono_test]
fn specialize_after_match() {
indoc!(
r#"
app "test" provides [main] to "./platform"
main =
listA : LinkedList Str
listA = Nil
listB : LinkedList Str
listB = Nil
longestLinkedList listA listB
LinkedList a : [Cons a (LinkedList a), Nil]
longestLinkedList : LinkedList a, LinkedList a -> Nat
longestLinkedList = \listA, listB -> when listA is
Nil -> linkedListLength listB
Cons a aa -> when listB is
Nil -> linkedListLength listA
Cons b bb ->
lengthA = (linkedListLength aa) + 1
lengthB = linkedListLength listB
if lengthA > lengthB
then lengthA
else lengthB
linkedListLength : LinkedList a -> Nat
linkedListLength = \list -> when list is
Nil -> 0
Cons _ rest -> 1 + linkedListLength rest
"#
)
}