mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-28 14:24:45 +00:00
add logic to generate/solve closure size constraints
This commit is contained in:
parent
bb6f36ad28
commit
05d1f28e83
8 changed files with 455 additions and 2 deletions
|
@ -6,7 +6,7 @@ use crate::num::{
|
||||||
finish_parsing_base, finish_parsing_float, finish_parsing_int, float_expr_from_result,
|
finish_parsing_base, finish_parsing_float, finish_parsing_int, float_expr_from_result,
|
||||||
int_expr_from_result, num_expr_from_result,
|
int_expr_from_result, num_expr_from_result,
|
||||||
};
|
};
|
||||||
use crate::pattern::{canonicalize_pattern, Pattern};
|
use crate::pattern::{canonicalize_pattern, symbols_from_pattern, Pattern};
|
||||||
use crate::procedure::References;
|
use crate::procedure::References;
|
||||||
use crate::scope::Scope;
|
use crate::scope::Scope;
|
||||||
use inlinable_string::InlinableString;
|
use inlinable_string::InlinableString;
|
||||||
|
@ -1525,3 +1525,130 @@ pub fn unescape_char(escaped: &EscapedChar) -> char {
|
||||||
Newline => '\n',
|
Newline => '\n',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn free_variables(expr: &Expr) -> MutSet<Symbol> {
|
||||||
|
use Expr::*;
|
||||||
|
|
||||||
|
let mut stack = vec![expr.clone()];
|
||||||
|
|
||||||
|
let mut bound = MutSet::default();
|
||||||
|
let mut used: std::collections::HashSet<roc_module::symbol::Symbol, _> = MutSet::default();
|
||||||
|
|
||||||
|
while let Some(expr) = stack.pop() {
|
||||||
|
match expr {
|
||||||
|
Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {}
|
||||||
|
|
||||||
|
Var(s) => {
|
||||||
|
used.insert(s);
|
||||||
|
}
|
||||||
|
List { loc_elems, .. } => {
|
||||||
|
for e in loc_elems {
|
||||||
|
stack.push(e.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
When {
|
||||||
|
loc_cond, branches, ..
|
||||||
|
} => {
|
||||||
|
stack.push(loc_cond.value);
|
||||||
|
|
||||||
|
for branch in branches {
|
||||||
|
stack.push(branch.value.value);
|
||||||
|
|
||||||
|
if let Some(guard) = branch.guard {
|
||||||
|
stack.push(guard.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
bound.extend(
|
||||||
|
branch
|
||||||
|
.patterns
|
||||||
|
.iter()
|
||||||
|
.map(|t| symbols_from_pattern(&t.value).into_iter())
|
||||||
|
.flatten(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
If {
|
||||||
|
branches,
|
||||||
|
final_else,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
for (cond, then) in branches {
|
||||||
|
stack.push(cond.value);
|
||||||
|
stack.push(then.value);
|
||||||
|
}
|
||||||
|
stack.push(final_else.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
LetRec(defs, cont, _, _) => {
|
||||||
|
stack.push(cont.value);
|
||||||
|
|
||||||
|
for def in defs {
|
||||||
|
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
|
||||||
|
stack.push(def.loc_expr.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LetNonRec(def, cont, _, _) => {
|
||||||
|
stack.push(cont.value);
|
||||||
|
|
||||||
|
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
|
||||||
|
stack.push(def.loc_expr.value);
|
||||||
|
}
|
||||||
|
Call(boxed, args, _) => {
|
||||||
|
let (_, function, _, _) = *boxed;
|
||||||
|
|
||||||
|
stack.push(function.value);
|
||||||
|
for (_, arg) in args {
|
||||||
|
stack.push(arg.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RunLowLevel { args, .. } => {
|
||||||
|
for (_, arg) in args {
|
||||||
|
stack.push(arg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Closure {
|
||||||
|
arguments: args,
|
||||||
|
loc_body: boxed_body,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
bound.extend(
|
||||||
|
args.iter()
|
||||||
|
.map(|t| symbols_from_pattern(&t.1.value).into_iter())
|
||||||
|
.flatten(),
|
||||||
|
);
|
||||||
|
stack.push(boxed_body.value);
|
||||||
|
}
|
||||||
|
Record { fields, .. } => {
|
||||||
|
for (_, field) in fields {
|
||||||
|
stack.push(field.loc_expr.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Update {
|
||||||
|
symbol, updates, ..
|
||||||
|
} => {
|
||||||
|
used.insert(symbol);
|
||||||
|
for (_, field) in updates {
|
||||||
|
stack.push(field.loc_expr.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Access { loc_expr, .. } => {
|
||||||
|
stack.push(loc_expr.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
Accessor { .. } => {}
|
||||||
|
|
||||||
|
Tag { arguments, .. } => {
|
||||||
|
for (_, arg) in arguments {
|
||||||
|
stack.push(arg.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for b in bound {
|
||||||
|
used.remove(&b);
|
||||||
|
}
|
||||||
|
|
||||||
|
used
|
||||||
|
}
|
||||||
|
|
|
@ -12,6 +12,8 @@ roc_module = { path = "../module" }
|
||||||
roc_types = { path = "../types" }
|
roc_types = { path = "../types" }
|
||||||
roc_can = { path = "../can" }
|
roc_can = { path = "../can" }
|
||||||
roc_unify = { path = "../unify" }
|
roc_unify = { path = "../unify" }
|
||||||
|
roc_constrain = { path = "../constrain" }
|
||||||
|
roc_solve = { path = "../solve" }
|
||||||
roc_problem = { path = "../problem" }
|
roc_problem = { path = "../problem" }
|
||||||
ven_pretty = { path = "../../vendor/pretty" }
|
ven_pretty = { path = "../../vendor/pretty" }
|
||||||
bumpalo = { version = "3.2", features = ["collections"] }
|
bumpalo = { version = "3.2", features = ["collections"] }
|
||||||
|
|
306
compiler/mono/src/closures.rs
Normal file
306
compiler/mono/src/closures.rs
Normal file
|
@ -0,0 +1,306 @@
|
||||||
|
use roc_can::constraint::Constraint;
|
||||||
|
use roc_can::expected::Expected;
|
||||||
|
use roc_can::expr::Expr;
|
||||||
|
use roc_can::pattern::symbols_from_pattern;
|
||||||
|
use roc_collections::all::MutSet;
|
||||||
|
use roc_constrain::expr::exists;
|
||||||
|
use roc_module::symbol::Symbol;
|
||||||
|
use roc_region::all::Region;
|
||||||
|
use roc_types::subs::{Subs, VarStore};
|
||||||
|
use roc_types::types::{Category, Type};
|
||||||
|
|
||||||
|
pub fn infer_closure_size(expr: &Expr, mut subs: Subs) -> Subs {
|
||||||
|
use roc_solve::solve;
|
||||||
|
|
||||||
|
let mut var_store = VarStore::new_from_subs(&mut subs);
|
||||||
|
|
||||||
|
let env = solve::Env::default();
|
||||||
|
let mut problems = Vec::new();
|
||||||
|
let constraint = generate_constraint(expr, &mut var_store);
|
||||||
|
|
||||||
|
let (solved_subs, _new_env) = solve::run(&env, &mut problems, subs, &constraint);
|
||||||
|
|
||||||
|
debug_assert_eq!(problems.len(), 0);
|
||||||
|
|
||||||
|
solved_subs.0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn free_variables(expr: &Expr) -> MutSet<Symbol> {
|
||||||
|
use Expr::*;
|
||||||
|
|
||||||
|
let mut stack = vec![expr.clone()];
|
||||||
|
|
||||||
|
let mut bound = MutSet::default();
|
||||||
|
let mut used: std::collections::HashSet<roc_module::symbol::Symbol, _> = MutSet::default();
|
||||||
|
|
||||||
|
while let Some(expr) = stack.pop() {
|
||||||
|
match expr {
|
||||||
|
Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {}
|
||||||
|
|
||||||
|
Var(s) => {
|
||||||
|
used.insert(s);
|
||||||
|
}
|
||||||
|
List { loc_elems, .. } => {
|
||||||
|
for e in loc_elems {
|
||||||
|
stack.push(e.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
When {
|
||||||
|
loc_cond, branches, ..
|
||||||
|
} => {
|
||||||
|
stack.push(loc_cond.value);
|
||||||
|
|
||||||
|
for branch in branches {
|
||||||
|
stack.push(branch.value.value);
|
||||||
|
|
||||||
|
if let Some(guard) = branch.guard {
|
||||||
|
stack.push(guard.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
bound.extend(
|
||||||
|
branch
|
||||||
|
.patterns
|
||||||
|
.iter()
|
||||||
|
.map(|t| symbols_from_pattern(&t.value).into_iter())
|
||||||
|
.flatten(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
If {
|
||||||
|
branches,
|
||||||
|
final_else,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
for (cond, then) in branches {
|
||||||
|
stack.push(cond.value);
|
||||||
|
stack.push(then.value);
|
||||||
|
}
|
||||||
|
stack.push(final_else.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
LetRec(defs, cont, _, _) => {
|
||||||
|
stack.push(cont.value);
|
||||||
|
|
||||||
|
for def in defs {
|
||||||
|
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
|
||||||
|
stack.push(def.loc_expr.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LetNonRec(def, cont, _, _) => {
|
||||||
|
stack.push(cont.value);
|
||||||
|
|
||||||
|
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
|
||||||
|
stack.push(def.loc_expr.value);
|
||||||
|
}
|
||||||
|
Call(boxed, args, _) => {
|
||||||
|
let (_, function, _, _) = *boxed;
|
||||||
|
|
||||||
|
stack.push(function.value);
|
||||||
|
for (_, arg) in args {
|
||||||
|
stack.push(arg.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RunLowLevel { args, .. } => {
|
||||||
|
for (_, arg) in args {
|
||||||
|
stack.push(arg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Closure {
|
||||||
|
arguments: args,
|
||||||
|
loc_body: boxed_body,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
bound.extend(
|
||||||
|
args.iter()
|
||||||
|
.map(|t| symbols_from_pattern(&t.1.value).into_iter())
|
||||||
|
.flatten(),
|
||||||
|
);
|
||||||
|
stack.push(boxed_body.value);
|
||||||
|
}
|
||||||
|
Record { fields, .. } => {
|
||||||
|
for (_, field) in fields {
|
||||||
|
stack.push(field.loc_expr.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Update {
|
||||||
|
symbol, updates, ..
|
||||||
|
} => {
|
||||||
|
used.insert(symbol);
|
||||||
|
for (_, field) in updates {
|
||||||
|
stack.push(field.loc_expr.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Access { loc_expr, .. } => {
|
||||||
|
stack.push(loc_expr.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
Accessor { .. } => {}
|
||||||
|
|
||||||
|
Tag { arguments, .. } => {
|
||||||
|
for (_, arg) in arguments {
|
||||||
|
stack.push(arg.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for b in bound {
|
||||||
|
used.remove(&b);
|
||||||
|
}
|
||||||
|
|
||||||
|
used
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_constraint(expr: &Expr, var_store: &mut VarStore) -> Constraint {
|
||||||
|
let mut constraints = Vec::new();
|
||||||
|
generate_constraints_help(expr, var_store, &mut constraints);
|
||||||
|
Constraint::And(constraints)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_constraints_help(
|
||||||
|
expr: &Expr,
|
||||||
|
var_store: &mut VarStore,
|
||||||
|
constraints: &mut Vec<Constraint>,
|
||||||
|
) {
|
||||||
|
use Expr::*;
|
||||||
|
|
||||||
|
match expr {
|
||||||
|
Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {}
|
||||||
|
|
||||||
|
Var(_) => {}
|
||||||
|
List { loc_elems, .. } => {
|
||||||
|
for e in loc_elems {
|
||||||
|
generate_constraints_help(&e.value, var_store, constraints);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
When {
|
||||||
|
loc_cond, branches, ..
|
||||||
|
} => {
|
||||||
|
generate_constraints_help(&loc_cond.value, var_store, constraints);
|
||||||
|
|
||||||
|
for branch in branches {
|
||||||
|
generate_constraints_help(&branch.value.value, var_store, constraints);
|
||||||
|
|
||||||
|
if let Some(guard) = &branch.guard {
|
||||||
|
generate_constraints_help(&guard.value, var_store, constraints);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
If {
|
||||||
|
branches,
|
||||||
|
final_else,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
for (cond, then) in branches {
|
||||||
|
generate_constraints_help(&cond.value, var_store, constraints);
|
||||||
|
generate_constraints_help(&then.value, var_store, constraints);
|
||||||
|
}
|
||||||
|
generate_constraints_help(&final_else.value, var_store, constraints);
|
||||||
|
}
|
||||||
|
|
||||||
|
LetRec(defs, cont, _, _) => {
|
||||||
|
generate_constraints_help(&cont.value, var_store, constraints);
|
||||||
|
|
||||||
|
for def in defs {
|
||||||
|
generate_constraints_help(&def.loc_expr.value, var_store, constraints);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LetNonRec(def, cont, _, _) => {
|
||||||
|
generate_constraints_help(&cont.value, var_store, constraints);
|
||||||
|
|
||||||
|
generate_constraints_help(&def.loc_expr.value, var_store, constraints);
|
||||||
|
}
|
||||||
|
Call(boxed, args, _) => {
|
||||||
|
let (_, function, _, _) = &**boxed;
|
||||||
|
|
||||||
|
generate_constraints_help(&function.value, var_store, constraints);
|
||||||
|
for (_, arg) in args {
|
||||||
|
generate_constraints_help(&arg.value, var_store, constraints);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RunLowLevel { args, .. } => {
|
||||||
|
for (_, arg) in args {
|
||||||
|
generate_constraints_help(&arg, var_store, constraints);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Closure {
|
||||||
|
arguments: _,
|
||||||
|
closure_type: closure_var,
|
||||||
|
loc_body: boxed_body,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
let mut cons = Vec::new();
|
||||||
|
let mut variables = Vec::new();
|
||||||
|
|
||||||
|
let closed_over_symbols = MutSet::default();
|
||||||
|
let closure_ext_var = var_store.fresh();
|
||||||
|
let closure_var = *closure_var;
|
||||||
|
|
||||||
|
variables.push(closure_ext_var);
|
||||||
|
// TODO unsure about including this one
|
||||||
|
variables.push(closure_var);
|
||||||
|
|
||||||
|
let mut tag_arguments = Vec::with_capacity(closed_over_symbols.len());
|
||||||
|
for symbol in closed_over_symbols {
|
||||||
|
let var = var_store.fresh();
|
||||||
|
variables.push(var);
|
||||||
|
tag_arguments.push(Type::Variable(var));
|
||||||
|
|
||||||
|
let region = Region::zero();
|
||||||
|
let expected = Expected::NoExpectation(Type::Variable(var));
|
||||||
|
let lookup = Constraint::Lookup(symbol, expected, region);
|
||||||
|
cons.push(lookup);
|
||||||
|
}
|
||||||
|
|
||||||
|
let tag_name_string = format!("Closure_{}", closure_var.index());
|
||||||
|
let tag_name = roc_module::ident::TagName::Global(tag_name_string.into());
|
||||||
|
let expected_type = Type::TagUnion(
|
||||||
|
vec![(tag_name, tag_arguments)],
|
||||||
|
Box::new(Type::Variable(closure_ext_var)),
|
||||||
|
);
|
||||||
|
|
||||||
|
// constrain this closures's size to the type we just created
|
||||||
|
let expected = Expected::NoExpectation(expected_type);
|
||||||
|
let category = Category::ClosureSize;
|
||||||
|
let region = boxed_body.region;
|
||||||
|
let equality = Constraint::Eq(Type::Variable(closure_var), expected, category, region);
|
||||||
|
|
||||||
|
cons.push(equality);
|
||||||
|
|
||||||
|
// generate constraints for nested closures
|
||||||
|
let mut inner_constraints = Vec::new();
|
||||||
|
generate_constraints_help(&boxed_body.value, var_store, &mut inner_constraints);
|
||||||
|
|
||||||
|
cons.push(Constraint::And(inner_constraints));
|
||||||
|
|
||||||
|
let constraint = exists(variables, Constraint::And(cons));
|
||||||
|
|
||||||
|
constraints.push(constraint);
|
||||||
|
}
|
||||||
|
Record { fields, .. } => {
|
||||||
|
for (_, field) in fields {
|
||||||
|
generate_constraints_help(&field.loc_expr.value, var_store, constraints);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Update {
|
||||||
|
symbol: _, updates, ..
|
||||||
|
} => {
|
||||||
|
for (_, field) in updates {
|
||||||
|
generate_constraints_help(&field.loc_expr.value, var_store, constraints);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Access { loc_expr, .. } => {
|
||||||
|
generate_constraints_help(&loc_expr.value, var_store, constraints);
|
||||||
|
}
|
||||||
|
|
||||||
|
Accessor { .. } => {}
|
||||||
|
|
||||||
|
Tag { arguments, .. } => {
|
||||||
|
for (_, arg) in arguments {
|
||||||
|
generate_constraints_help(&arg.value, var_store, constraints);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -12,6 +12,7 @@
|
||||||
#![allow(clippy::large_enum_variant)]
|
#![allow(clippy::large_enum_variant)]
|
||||||
|
|
||||||
pub mod borrow;
|
pub mod borrow;
|
||||||
|
pub mod closures;
|
||||||
pub mod inc_dec;
|
pub mod inc_dec;
|
||||||
pub mod ir;
|
pub mod ir;
|
||||||
pub mod layout;
|
pub mod layout;
|
||||||
|
|
|
@ -882,6 +882,11 @@ fn add_category<'b>(
|
||||||
|
|
||||||
Lambda => alloc.concat(vec![this_is, alloc.text(" an anonymous function of type:")]),
|
Lambda => alloc.concat(vec![this_is, alloc.text(" an anonymous function of type:")]),
|
||||||
|
|
||||||
|
ClosureSize => alloc.concat(vec![
|
||||||
|
this_is,
|
||||||
|
alloc.text(" the closure size of a function of type:"),
|
||||||
|
]),
|
||||||
|
|
||||||
TagApply {
|
TagApply {
|
||||||
tag_name: TagName::Global(name),
|
tag_name: TagName::Global(name),
|
||||||
args_count: 0,
|
args_count: 0,
|
||||||
|
|
|
@ -70,7 +70,7 @@ pub enum TypeError {
|
||||||
BadType(roc_types::types::Problem),
|
BadType(roc_types::types::Problem),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug, Default)]
|
||||||
pub struct Env {
|
pub struct Env {
|
||||||
pub vars_by_symbol: SendMap<Symbol, Variable>,
|
pub vars_by_symbol: SendMap<Symbol, Variable>,
|
||||||
pub aliases: MutMap<Symbol, Alias>,
|
pub aliases: MutMap<Symbol, Alias>,
|
||||||
|
|
|
@ -72,6 +72,13 @@ impl VarStore {
|
||||||
VarStore { next: next_var.0 }
|
VarStore { next: next_var.0 }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn new_from_subs(subs: &Subs) -> Self {
|
||||||
|
let next_var = subs.utable.len() as u32;
|
||||||
|
debug_assert!(next_var >= Variable::FIRST_USER_SPACE_VAR.0);
|
||||||
|
|
||||||
|
VarStore { next: next_var }
|
||||||
|
}
|
||||||
|
|
||||||
pub fn fresh(&mut self) -> Variable {
|
pub fn fresh(&mut self) -> Variable {
|
||||||
// Increment the counter and return the value it had before it was incremented.
|
// Increment the counter and return the value it had before it was incremented.
|
||||||
let answer = self.next;
|
let answer = self.next;
|
||||||
|
@ -163,6 +170,10 @@ impl Variable {
|
||||||
pub unsafe fn unsafe_test_debug_variable(v: u32) -> Self {
|
pub unsafe fn unsafe_test_debug_variable(v: u32) -> Self {
|
||||||
Variable(v)
|
Variable(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn index(&self) -> u32 {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Into<OptVariable> for Variable {
|
impl Into<OptVariable> for Variable {
|
||||||
|
|
|
@ -942,6 +942,7 @@ pub enum Category {
|
||||||
},
|
},
|
||||||
Lambda,
|
Lambda,
|
||||||
Uniqueness,
|
Uniqueness,
|
||||||
|
ClosureSize,
|
||||||
StrInterpolation,
|
StrInterpolation,
|
||||||
|
|
||||||
// storing variables in the ast
|
// storing variables in the ast
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue