fix oversights

This commit is contained in:
Folkert 2021-01-28 15:32:22 +01:00
parent 74e94869e3
commit 55eff1dba1
10 changed files with 87 additions and 128 deletions

View file

@ -2162,19 +2162,18 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
} }
DecRef(symbol) => { DecRef(symbol) => {
let (value, layout) = load_symbol_and_layout(scope, symbol); let (value, layout) = load_symbol_and_layout(scope, symbol);
debug_assert!(layout.is_refcounted());
let value_ptr = value.into_pointer_value(); if layout.is_refcounted() {
let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); let value_ptr = value.into_pointer_value();
refcount_ptr.decrement(env, layout); let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr);
refcount_ptr.decrement(env, layout);
}
build_exp_stmt(env, layout_ids, scope, parent, cont) build_exp_stmt(env, layout_ids, scope, parent, cont)
} }
} }
} }
Info(_, cont) => build_exp_stmt(env, layout_ids, scope, parent, cont),
RuntimeError(error_msg) => { RuntimeError(error_msg) => {
throw_exception(env, error_msg); throw_exception(env, error_msg);

View file

@ -1936,6 +1936,7 @@ mod gen_primitives {
} }
#[test] #[test]
#[ignore]
fn rosetree_basic() { fn rosetree_basic() {
assert_non_opt_evals_to!( assert_non_opt_evals_to!(
indoc!( indoc!(

View file

@ -410,9 +410,6 @@ where
self.set_last_seen(sym, stmt); self.set_last_seen(sym, stmt);
self.scan_ast(following); self.scan_ast(following);
} }
Stmt::Info(_, following) => {
self.scan_ast(following);
}
Stmt::Join { Stmt::Join {
parameters, parameters,
continuation, continuation,

View file

@ -1783,20 +1783,24 @@ fn update<'a>(
.dependencies .dependencies
.notify(module_id, Phase::MakeSpecializations); .notify(module_id, Phase::MakeSpecializations);
state.procedures.extend(procedures);
state.timings.insert(module_id, module_timing);
if state.dependencies.solved_all() && state.goal_phase == Phase::MakeSpecializations { if state.dependencies.solved_all() && state.goal_phase == Phase::MakeSpecializations {
debug_assert!(work.is_empty(), "still work remaining {:?}", &work); debug_assert!(work.is_empty(), "still work remaining {:?}", &work);
Proc::insert_refcount_operations(arena, &mut state.procedures); Proc::insert_refcount_operations(arena, &mut state.procedures);
Proc::optimize_refcount_operations(
arena,
module_id,
&mut ident_ids,
&mut state.procedures,
);
state.procedures.extend(procedures); if false {
Proc::optimize_refcount_operations(
arena,
module_id,
&mut ident_ids,
&mut state.procedures,
);
}
state.constrained_ident_ids.insert(module_id, ident_ids); state.constrained_ident_ids.insert(module_id, ident_ids);
state.timings.insert(module_id, module_timing);
for (module_id, requested) in external_specializations_requested { for (module_id, requested) in external_specializations_requested {
let existing = match state let existing = match state
@ -1836,9 +1840,7 @@ fn update<'a>(
// the originally requested module, we're all done! // the originally requested module, we're all done!
return Ok(state); return Ok(state);
} else { } else {
state.procedures.extend(procedures);
state.constrained_ident_ids.insert(module_id, ident_ids); state.constrained_ident_ids.insert(module_id, ident_ids);
state.timings.insert(module_id, module_timing);
for (module_id, requested) in external_specializations_requested { for (module_id, requested) in external_specializations_requested {
let existing = match state let existing = match state

View file

@ -156,9 +156,6 @@ impl<'a> ParamMap<'a> {
Let(_, _, _, cont) => { Let(_, _, _, cont) => {
stack.push(cont); stack.push(cont);
} }
Info(_, cont) => {
stack.push(cont);
}
Invoke { pass, fail, .. } => { Invoke { pass, fail, .. } => {
stack.push(pass); stack.push(pass);
stack.push(fail); stack.push(fail);
@ -465,10 +462,6 @@ impl<'a> BorrowInfState<'a> {
self.collect_stmt(b); self.collect_stmt(b);
} }
Info(_, cont) => {
self.collect_stmt(cont);
}
Let(x, Expr::FunctionPointer(fsymbol, layout), _, b) => { Let(x, Expr::FunctionPointer(fsymbol, layout), _, b) => {
// ensure that the function pointed to is in the param map // ensure that the function pointed to is in the param map
if let Some(params) = self.param_map.get_symbol(*fsymbol) { if let Some(params) = self.param_map.get_symbol(*fsymbol) {

View file

@ -178,6 +178,7 @@ fn layout_for_constructor<'a>(
layout: &Layout<'a>, layout: &Layout<'a>,
constructor: u64, constructor: u64,
) -> ConstructorLayout<&'a [Layout<'a>]> { ) -> ConstructorLayout<&'a [Layout<'a>]> {
use ConstructorLayout::*;
use Layout::*; use Layout::*;
match layout { match layout {
@ -194,7 +195,21 @@ fn layout_for_constructor<'a>(
ConstructorLayout::HasFields(other_fields) ConstructorLayout::HasFields(other_fields)
} }
} }
_ => todo!(), NullableWrapped {
nullable_id,
other_tags,
} => {
if constructor as i64 == *nullable_id {
ConstructorLayout::IsNull
} else {
ConstructorLayout::HasFields(other_tags[constructor as usize])
}
}
NonRecursive(fields) | Recursive(fields) => HasFields(fields[constructor as usize]),
NonNullableUnwrapped(fields) => {
debug_assert_eq!(constructor, 0);
HasFields(fields)
}
} }
} }
_ => unreachable!(), _ => unreachable!(),
@ -404,22 +419,6 @@ pub fn expand_and_cancel<'a>(env: &mut Env<'a, '_>, stmt: &'a Stmt<'a>) -> &'a S
expand_and_cancel(env, cont) expand_and_cancel(env, cont)
} }
Info(info, cont) => {
env.constructor_map
.insert(info.scrutinee, info.tag_id as u64);
env.layout_map.insert(info.scrutinee, info.layout.clone());
let cont = expand_and_cancel(env, cont);
env.constructor_map.remove(&info.scrutinee);
env.layout_map.remove(&info.scrutinee);
let stmt = Info(info.clone(), cont);
env.arena.alloc(stmt)
}
Invoke { Invoke {
symbol, symbol,
call, call,
@ -469,6 +468,27 @@ pub fn expand_and_cancel<'a>(env: &mut Env<'a, '_>, stmt: &'a Stmt<'a>) -> &'a S
result = env.arena.alloc(stmt); result = env.arena.alloc(stmt);
} }
// do all decrements
for (symbol, amount) in deferred.inc_dec_map.iter().rev() {
use std::cmp::Ordering;
match amount.cmp(&0) {
Ordering::Equal => {
// do nothing else
}
Ordering::Greater => {
// do nothing yet
}
Ordering::Less => {
// the RC insertion should not double decrement in a block
debug_assert_eq!(*amount, -1);
// insert missing decrements
let stmt = Refcounting(ModifyRc::Dec(*symbol), result);
result = env.arena.alloc(stmt);
}
}
}
for (symbol, amount) in deferred.inc_dec_map.into_iter().rev() { for (symbol, amount) in deferred.inc_dec_map.into_iter().rev() {
use std::cmp::Ordering; use std::cmp::Ordering;
match amount.cmp(&0) { match amount.cmp(&0) {
@ -481,12 +501,7 @@ pub fn expand_and_cancel<'a>(env: &mut Env<'a, '_>, stmt: &'a Stmt<'a>) -> &'a S
result = env.arena.alloc(stmt); result = env.arena.alloc(stmt);
} }
Ordering::Less => { Ordering::Less => {
// the RC insertion should not double decrement in a block // already done
debug_assert_eq!(amount, -1);
// insert missing decrements
let stmt = Refcounting(ModifyRc::Dec(symbol), result);
result = env.arena.alloc(stmt);
} }
} }
} }

View file

@ -32,10 +32,6 @@ pub fn occuring_variables(stmt: &Stmt<'_>) -> (MutSet<Symbol>, MutSet<Symbol>) {
stack.push(cont); stack.push(cont);
} }
Info(_, cont) => {
stack.push(cont);
}
Invoke { Invoke {
symbol, symbol,
call, call,
@ -712,16 +708,6 @@ impl<'a> Context<'a> {
) )
} }
Info(info, cont) => {
let (cont, live_vars) = self.visit_stmt(cont);
let stmt = Info(info.clone(), cont);
let stmt = self.arena.alloc(stmt);
(stmt, live_vars)
}
Invoke { Invoke {
symbol, symbol,
call, call,
@ -928,20 +914,6 @@ pub fn collect_stmt(
collect_stmt(cont, jp_live_vars, vars) collect_stmt(cont, jp_live_vars, vars)
} }
Info(_, cont) => collect_stmt(cont, jp_live_vars, vars),
Jump(id, arguments) => {
vars.extend(arguments.iter().copied());
// NOTE deviation from Lean
// we fall through when no join point is available
if let Some(jvars) = jp_live_vars.get(id) {
vars.extend(jvars);
}
vars
}
Join { Join {
id: j, id: j,
parameters, parameters,
@ -959,6 +931,18 @@ pub fn collect_stmt(
collect_stmt(b, &jp_live_vars, vars) collect_stmt(b, &jp_live_vars, vars)
} }
Jump(id, arguments) => {
vars.extend(arguments.iter().copied());
// NOTE deviation from Lean
// we fall through when no join point is available
if let Some(jvars) = jp_live_vars.get(id) {
vars.extend(jvars);
}
vars
}
Switch { Switch {
cond_symbol, cond_symbol,
branches, branches,

View file

@ -17,7 +17,7 @@ use roc_types::subs::{Content, FlatType, Subs, Variable};
use std::collections::HashMap; use std::collections::HashMap;
use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder}; use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder};
pub const PRETTY_PRINT_IR_SYMBOLS: bool = true; pub const PRETTY_PRINT_IR_SYMBOLS: bool = false;
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum MonoProblem { pub enum MonoProblem {
@ -796,7 +796,6 @@ pub enum Stmt<'a> {
}, },
Ret(Symbol), Ret(Symbol),
Rethrow, Rethrow,
Info(ConstructorInfo<'a>, &'a Stmt<'a>),
Refcounting(ModifyRc, &'a Stmt<'a>), Refcounting(ModifyRc, &'a Stmt<'a>),
Join { Join {
id: JoinPointId, id: JoinPointId,
@ -895,14 +894,6 @@ impl ModifyRc {
} }
} }
/// in the block below, symbol `scrutinee` is assumed be be of shape `tag_id`
#[derive(Clone, Debug, PartialEq)]
pub struct ConstructorInfo<'a> {
pub scrutinee: Symbol,
pub layout: Layout<'a>,
pub tag_id: u8,
}
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum Literal<'a> { pub enum Literal<'a> {
// Literals // Literals
@ -1256,28 +1247,13 @@ impl<'a> Stmt<'a> {
.text("let ") .text("let ")
.append(symbol_to_doc(alloc, *symbol)) .append(symbol_to_doc(alloc, *symbol))
//.append(" : ") //.append(" : ")
//.append(alloc.text(format!("{:?}", layout))) //.append(alloc.text(format!("{:?}", _layout)))
.append(" = ") .append(" = ")
.append(expr.to_doc(alloc)) .append(expr.to_doc(alloc))
.append(";") .append(";")
.append(alloc.hardline()) .append(alloc.hardline())
.append(cont.to_doc(alloc)), .append(cont.to_doc(alloc)),
Info(info, cont) => alloc
.text("info:")
.append(alloc.hardline())
.append(alloc.text(" scrutinee: "))
.append(symbol_to_doc(alloc, info.scrutinee))
.append(alloc.hardline())
.append(alloc.text(" tag_id: "))
.append(format!("{:?}", info.tag_id))
.append(alloc.hardline())
.append(alloc.text(" layout: "))
.append(format!("{:?}", info.layout))
.append(";")
.append(alloc.hardline())
.append(cont.to_doc(alloc)),
Refcounting(modify, cont) => modify Refcounting(modify, cont) => modify
.to_doc(alloc) .to_doc(alloc)
.append(alloc.hardline()) .append(alloc.hardline())
@ -4844,10 +4820,6 @@ fn substitute_in_stmt_help<'a>(
None => None, None => None,
} }
} }
Info(info, cont) => match substitute_in_stmt_help(arena, cont, subs) {
Some(cont) => Some(arena.alloc(Info(info.clone(), cont))),
None => None,
},
Jump(id, args) => { Jump(id, args) => {
let mut did_change = false; let mut did_change = false;

View file

@ -235,10 +235,6 @@ fn insert_jumps<'a>(
Some(cont) => Some(arena.alloc(Refcounting(*modify, cont))), Some(cont) => Some(arena.alloc(Refcounting(*modify, cont))),
None => None, None => None,
}, },
Info(info, cont) => match insert_jumps(arena, cont, goal_id, needle) {
Some(cont) => Some(arena.alloc(Info(info.clone(), cont))),
None => None,
},
Rethrow => None, Rethrow => None,
Ret(_) => None, Ret(_) => None,

View file

@ -856,14 +856,14 @@ mod test_mono {
joinpoint Test.8 Test.3: joinpoint Test.8 Test.3:
ret Test.3; ret Test.3;
in in
let Test.12 = 1i64; let Test.15 = 1i64;
let Test.13 = Index 0 Test.2; let Test.16 = Index 0 Test.2;
let Test.17 = lowlevel Eq Test.12 Test.13; let Test.17 = lowlevel Eq Test.15 Test.16;
if Test.17 then if Test.17 then
let Test.14 = Index 1 Test.2; let Test.12 = Index 1 Test.2;
let Test.15 = 3i64; let Test.13 = 3i64;
let Test.16 = lowlevel Eq Test.15 Test.14; let Test.14 = lowlevel Eq Test.13 Test.12;
if Test.16 then if Test.14 then
let Test.9 = 1i64; let Test.9 = 1i64;
jump Test.8 Test.9; jump Test.8 Test.9;
else else
@ -1933,17 +1933,17 @@ mod test_mono {
let Test.16 = S Test.19 Test.18; let Test.16 = S Test.19 Test.18;
let Test.14 = S Test.17 Test.16; let Test.14 = S Test.17 Test.16;
let Test.2 = S Test.15 Test.14; let Test.2 = S Test.15 Test.14;
let Test.7 = 0i64; let Test.11 = 0i64;
let Test.8 = Index 0 Test.2; let Test.12 = Index 0 Test.2;
let Test.13 = lowlevel Eq Test.7 Test.8; let Test.13 = lowlevel Eq Test.11 Test.12;
if Test.13 then if Test.13 then
let Test.9 = Index 1 Test.2; let Test.7 = Index 1 Test.2;
inc Test.9; inc Test.7;
let Test.10 = 0i64; let Test.8 = 0i64;
let Test.11 = Index 0 Test.9; let Test.9 = Index 0 Test.7;
dec Test.9; dec Test.7;
let Test.12 = lowlevel Eq Test.10 Test.11; let Test.10 = lowlevel Eq Test.8 Test.9;
if Test.12 then if Test.10 then
let Test.4 = Index 1 Test.2; let Test.4 = Index 1 Test.2;
inc Test.4; inc Test.4;
dec Test.2; dec Test.2;