add logic to generate/solve closure size constraints

This commit is contained in:
Folkert 2020-10-02 20:46:47 +02:00
parent bb6f36ad28
commit 05d1f28e83
8 changed files with 455 additions and 2 deletions

View file

@ -6,7 +6,7 @@ use crate::num::{
finish_parsing_base, finish_parsing_float, finish_parsing_int, float_expr_from_result, finish_parsing_base, finish_parsing_float, finish_parsing_int, float_expr_from_result,
int_expr_from_result, num_expr_from_result, int_expr_from_result, num_expr_from_result,
}; };
use crate::pattern::{canonicalize_pattern, Pattern}; use crate::pattern::{canonicalize_pattern, symbols_from_pattern, Pattern};
use crate::procedure::References; use crate::procedure::References;
use crate::scope::Scope; use crate::scope::Scope;
use inlinable_string::InlinableString; use inlinable_string::InlinableString;
@ -1525,3 +1525,130 @@ pub fn unescape_char(escaped: &EscapedChar) -> char {
Newline => '\n', Newline => '\n',
} }
} }
pub fn free_variables(expr: &Expr) -> MutSet<Symbol> {
use Expr::*;
let mut stack = vec![expr.clone()];
let mut bound = MutSet::default();
let mut used: std::collections::HashSet<roc_module::symbol::Symbol, _> = MutSet::default();
while let Some(expr) = stack.pop() {
match expr {
Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {}
Var(s) => {
used.insert(s);
}
List { loc_elems, .. } => {
for e in loc_elems {
stack.push(e.value);
}
}
When {
loc_cond, branches, ..
} => {
stack.push(loc_cond.value);
for branch in branches {
stack.push(branch.value.value);
if let Some(guard) = branch.guard {
stack.push(guard.value);
}
bound.extend(
branch
.patterns
.iter()
.map(|t| symbols_from_pattern(&t.value).into_iter())
.flatten(),
);
}
}
If {
branches,
final_else,
..
} => {
for (cond, then) in branches {
stack.push(cond.value);
stack.push(then.value);
}
stack.push(final_else.value);
}
LetRec(defs, cont, _, _) => {
stack.push(cont.value);
for def in defs {
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
stack.push(def.loc_expr.value);
}
}
LetNonRec(def, cont, _, _) => {
stack.push(cont.value);
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
stack.push(def.loc_expr.value);
}
Call(boxed, args, _) => {
let (_, function, _, _) = *boxed;
stack.push(function.value);
for (_, arg) in args {
stack.push(arg.value);
}
}
RunLowLevel { args, .. } => {
for (_, arg) in args {
stack.push(arg);
}
}
Closure {
arguments: args,
loc_body: boxed_body,
..
} => {
bound.extend(
args.iter()
.map(|t| symbols_from_pattern(&t.1.value).into_iter())
.flatten(),
);
stack.push(boxed_body.value);
}
Record { fields, .. } => {
for (_, field) in fields {
stack.push(field.loc_expr.value);
}
}
Update {
symbol, updates, ..
} => {
used.insert(symbol);
for (_, field) in updates {
stack.push(field.loc_expr.value);
}
}
Access { loc_expr, .. } => {
stack.push(loc_expr.value);
}
Accessor { .. } => {}
Tag { arguments, .. } => {
for (_, arg) in arguments {
stack.push(arg.value);
}
}
}
}
for b in bound {
used.remove(&b);
}
used
}

View file

@ -12,6 +12,8 @@ roc_module = { path = "../module" }
roc_types = { path = "../types" } roc_types = { path = "../types" }
roc_can = { path = "../can" } roc_can = { path = "../can" }
roc_unify = { path = "../unify" } roc_unify = { path = "../unify" }
roc_constrain = { path = "../constrain" }
roc_solve = { path = "../solve" }
roc_problem = { path = "../problem" } roc_problem = { path = "../problem" }
ven_pretty = { path = "../../vendor/pretty" } ven_pretty = { path = "../../vendor/pretty" }
bumpalo = { version = "3.2", features = ["collections"] } bumpalo = { version = "3.2", features = ["collections"] }

View file

@ -0,0 +1,306 @@
use roc_can::constraint::Constraint;
use roc_can::expected::Expected;
use roc_can::expr::Expr;
use roc_can::pattern::symbols_from_pattern;
use roc_collections::all::MutSet;
use roc_constrain::expr::exists;
use roc_module::symbol::Symbol;
use roc_region::all::Region;
use roc_types::subs::{Subs, VarStore};
use roc_types::types::{Category, Type};
pub fn infer_closure_size(expr: &Expr, mut subs: Subs) -> Subs {
use roc_solve::solve;
let mut var_store = VarStore::new_from_subs(&mut subs);
let env = solve::Env::default();
let mut problems = Vec::new();
let constraint = generate_constraint(expr, &mut var_store);
let (solved_subs, _new_env) = solve::run(&env, &mut problems, subs, &constraint);
debug_assert_eq!(problems.len(), 0);
solved_subs.0
}
pub fn free_variables(expr: &Expr) -> MutSet<Symbol> {
use Expr::*;
let mut stack = vec![expr.clone()];
let mut bound = MutSet::default();
let mut used: std::collections::HashSet<roc_module::symbol::Symbol, _> = MutSet::default();
while let Some(expr) = stack.pop() {
match expr {
Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {}
Var(s) => {
used.insert(s);
}
List { loc_elems, .. } => {
for e in loc_elems {
stack.push(e.value);
}
}
When {
loc_cond, branches, ..
} => {
stack.push(loc_cond.value);
for branch in branches {
stack.push(branch.value.value);
if let Some(guard) = branch.guard {
stack.push(guard.value);
}
bound.extend(
branch
.patterns
.iter()
.map(|t| symbols_from_pattern(&t.value).into_iter())
.flatten(),
);
}
}
If {
branches,
final_else,
..
} => {
for (cond, then) in branches {
stack.push(cond.value);
stack.push(then.value);
}
stack.push(final_else.value);
}
LetRec(defs, cont, _, _) => {
stack.push(cont.value);
for def in defs {
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
stack.push(def.loc_expr.value);
}
}
LetNonRec(def, cont, _, _) => {
stack.push(cont.value);
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
stack.push(def.loc_expr.value);
}
Call(boxed, args, _) => {
let (_, function, _, _) = *boxed;
stack.push(function.value);
for (_, arg) in args {
stack.push(arg.value);
}
}
RunLowLevel { args, .. } => {
for (_, arg) in args {
stack.push(arg);
}
}
Closure {
arguments: args,
loc_body: boxed_body,
..
} => {
bound.extend(
args.iter()
.map(|t| symbols_from_pattern(&t.1.value).into_iter())
.flatten(),
);
stack.push(boxed_body.value);
}
Record { fields, .. } => {
for (_, field) in fields {
stack.push(field.loc_expr.value);
}
}
Update {
symbol, updates, ..
} => {
used.insert(symbol);
for (_, field) in updates {
stack.push(field.loc_expr.value);
}
}
Access { loc_expr, .. } => {
stack.push(loc_expr.value);
}
Accessor { .. } => {}
Tag { arguments, .. } => {
for (_, arg) in arguments {
stack.push(arg.value);
}
}
}
}
for b in bound {
used.remove(&b);
}
used
}
pub fn generate_constraint(expr: &Expr, var_store: &mut VarStore) -> Constraint {
let mut constraints = Vec::new();
generate_constraints_help(expr, var_store, &mut constraints);
Constraint::And(constraints)
}
pub fn generate_constraints_help(
expr: &Expr,
var_store: &mut VarStore,
constraints: &mut Vec<Constraint>,
) {
use Expr::*;
match expr {
Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {}
Var(_) => {}
List { loc_elems, .. } => {
for e in loc_elems {
generate_constraints_help(&e.value, var_store, constraints);
}
}
When {
loc_cond, branches, ..
} => {
generate_constraints_help(&loc_cond.value, var_store, constraints);
for branch in branches {
generate_constraints_help(&branch.value.value, var_store, constraints);
if let Some(guard) = &branch.guard {
generate_constraints_help(&guard.value, var_store, constraints);
}
}
}
If {
branches,
final_else,
..
} => {
for (cond, then) in branches {
generate_constraints_help(&cond.value, var_store, constraints);
generate_constraints_help(&then.value, var_store, constraints);
}
generate_constraints_help(&final_else.value, var_store, constraints);
}
LetRec(defs, cont, _, _) => {
generate_constraints_help(&cont.value, var_store, constraints);
for def in defs {
generate_constraints_help(&def.loc_expr.value, var_store, constraints);
}
}
LetNonRec(def, cont, _, _) => {
generate_constraints_help(&cont.value, var_store, constraints);
generate_constraints_help(&def.loc_expr.value, var_store, constraints);
}
Call(boxed, args, _) => {
let (_, function, _, _) = &**boxed;
generate_constraints_help(&function.value, var_store, constraints);
for (_, arg) in args {
generate_constraints_help(&arg.value, var_store, constraints);
}
}
RunLowLevel { args, .. } => {
for (_, arg) in args {
generate_constraints_help(&arg, var_store, constraints);
}
}
Closure {
arguments: _,
closure_type: closure_var,
loc_body: boxed_body,
..
} => {
let mut cons = Vec::new();
let mut variables = Vec::new();
let closed_over_symbols = MutSet::default();
let closure_ext_var = var_store.fresh();
let closure_var = *closure_var;
variables.push(closure_ext_var);
// TODO unsure about including this one
variables.push(closure_var);
let mut tag_arguments = Vec::with_capacity(closed_over_symbols.len());
for symbol in closed_over_symbols {
let var = var_store.fresh();
variables.push(var);
tag_arguments.push(Type::Variable(var));
let region = Region::zero();
let expected = Expected::NoExpectation(Type::Variable(var));
let lookup = Constraint::Lookup(symbol, expected, region);
cons.push(lookup);
}
let tag_name_string = format!("Closure_{}", closure_var.index());
let tag_name = roc_module::ident::TagName::Global(tag_name_string.into());
let expected_type = Type::TagUnion(
vec![(tag_name, tag_arguments)],
Box::new(Type::Variable(closure_ext_var)),
);
// constrain this closures's size to the type we just created
let expected = Expected::NoExpectation(expected_type);
let category = Category::ClosureSize;
let region = boxed_body.region;
let equality = Constraint::Eq(Type::Variable(closure_var), expected, category, region);
cons.push(equality);
// generate constraints for nested closures
let mut inner_constraints = Vec::new();
generate_constraints_help(&boxed_body.value, var_store, &mut inner_constraints);
cons.push(Constraint::And(inner_constraints));
let constraint = exists(variables, Constraint::And(cons));
constraints.push(constraint);
}
Record { fields, .. } => {
for (_, field) in fields {
generate_constraints_help(&field.loc_expr.value, var_store, constraints);
}
}
Update {
symbol: _, updates, ..
} => {
for (_, field) in updates {
generate_constraints_help(&field.loc_expr.value, var_store, constraints);
}
}
Access { loc_expr, .. } => {
generate_constraints_help(&loc_expr.value, var_store, constraints);
}
Accessor { .. } => {}
Tag { arguments, .. } => {
for (_, arg) in arguments {
generate_constraints_help(&arg.value, var_store, constraints);
}
}
}
}

View file

@ -12,6 +12,7 @@
#![allow(clippy::large_enum_variant)] #![allow(clippy::large_enum_variant)]
pub mod borrow; pub mod borrow;
pub mod closures;
pub mod inc_dec; pub mod inc_dec;
pub mod ir; pub mod ir;
pub mod layout; pub mod layout;

View file

@ -882,6 +882,11 @@ fn add_category<'b>(
Lambda => alloc.concat(vec![this_is, alloc.text(" an anonymous function of type:")]), Lambda => alloc.concat(vec![this_is, alloc.text(" an anonymous function of type:")]),
ClosureSize => alloc.concat(vec![
this_is,
alloc.text(" the closure size of a function of type:"),
]),
TagApply { TagApply {
tag_name: TagName::Global(name), tag_name: TagName::Global(name),
args_count: 0, args_count: 0,

View file

@ -70,7 +70,7 @@ pub enum TypeError {
BadType(roc_types::types::Problem), BadType(roc_types::types::Problem),
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
pub struct Env { pub struct Env {
pub vars_by_symbol: SendMap<Symbol, Variable>, pub vars_by_symbol: SendMap<Symbol, Variable>,
pub aliases: MutMap<Symbol, Alias>, pub aliases: MutMap<Symbol, Alias>,

View file

@ -72,6 +72,13 @@ impl VarStore {
VarStore { next: next_var.0 } VarStore { next: next_var.0 }
} }
pub fn new_from_subs(subs: &Subs) -> Self {
let next_var = subs.utable.len() as u32;
debug_assert!(next_var >= Variable::FIRST_USER_SPACE_VAR.0);
VarStore { next: next_var }
}
pub fn fresh(&mut self) -> Variable { pub fn fresh(&mut self) -> Variable {
// Increment the counter and return the value it had before it was incremented. // Increment the counter and return the value it had before it was incremented.
let answer = self.next; let answer = self.next;
@ -163,6 +170,10 @@ impl Variable {
pub unsafe fn unsafe_test_debug_variable(v: u32) -> Self { pub unsafe fn unsafe_test_debug_variable(v: u32) -> Self {
Variable(v) Variable(v)
} }
pub fn index(&self) -> u32 {
self.0
}
} }
impl Into<OptVariable> for Variable { impl Into<OptVariable> for Variable {

View file

@ -942,6 +942,7 @@ pub enum Category {
}, },
Lambda, Lambda,
Uniqueness, Uniqueness,
ClosureSize,
StrInterpolation, StrInterpolation,
// storing variables in the ast // storing variables in the ast