mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-28 22:34:45 +00:00
add logic to generate/solve closure size constraints
This commit is contained in:
parent
bb6f36ad28
commit
05d1f28e83
8 changed files with 455 additions and 2 deletions
|
@ -6,7 +6,7 @@ use crate::num::{
|
|||
finish_parsing_base, finish_parsing_float, finish_parsing_int, float_expr_from_result,
|
||||
int_expr_from_result, num_expr_from_result,
|
||||
};
|
||||
use crate::pattern::{canonicalize_pattern, Pattern};
|
||||
use crate::pattern::{canonicalize_pattern, symbols_from_pattern, Pattern};
|
||||
use crate::procedure::References;
|
||||
use crate::scope::Scope;
|
||||
use inlinable_string::InlinableString;
|
||||
|
@ -1525,3 +1525,130 @@ pub fn unescape_char(escaped: &EscapedChar) -> char {
|
|||
Newline => '\n',
|
||||
}
|
||||
}
|
||||
|
||||
pub fn free_variables(expr: &Expr) -> MutSet<Symbol> {
|
||||
use Expr::*;
|
||||
|
||||
let mut stack = vec![expr.clone()];
|
||||
|
||||
let mut bound = MutSet::default();
|
||||
let mut used: std::collections::HashSet<roc_module::symbol::Symbol, _> = MutSet::default();
|
||||
|
||||
while let Some(expr) = stack.pop() {
|
||||
match expr {
|
||||
Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {}
|
||||
|
||||
Var(s) => {
|
||||
used.insert(s);
|
||||
}
|
||||
List { loc_elems, .. } => {
|
||||
for e in loc_elems {
|
||||
stack.push(e.value);
|
||||
}
|
||||
}
|
||||
When {
|
||||
loc_cond, branches, ..
|
||||
} => {
|
||||
stack.push(loc_cond.value);
|
||||
|
||||
for branch in branches {
|
||||
stack.push(branch.value.value);
|
||||
|
||||
if let Some(guard) = branch.guard {
|
||||
stack.push(guard.value);
|
||||
}
|
||||
|
||||
bound.extend(
|
||||
branch
|
||||
.patterns
|
||||
.iter()
|
||||
.map(|t| symbols_from_pattern(&t.value).into_iter())
|
||||
.flatten(),
|
||||
);
|
||||
}
|
||||
}
|
||||
If {
|
||||
branches,
|
||||
final_else,
|
||||
..
|
||||
} => {
|
||||
for (cond, then) in branches {
|
||||
stack.push(cond.value);
|
||||
stack.push(then.value);
|
||||
}
|
||||
stack.push(final_else.value);
|
||||
}
|
||||
|
||||
LetRec(defs, cont, _, _) => {
|
||||
stack.push(cont.value);
|
||||
|
||||
for def in defs {
|
||||
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
|
||||
stack.push(def.loc_expr.value);
|
||||
}
|
||||
}
|
||||
LetNonRec(def, cont, _, _) => {
|
||||
stack.push(cont.value);
|
||||
|
||||
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
|
||||
stack.push(def.loc_expr.value);
|
||||
}
|
||||
Call(boxed, args, _) => {
|
||||
let (_, function, _, _) = *boxed;
|
||||
|
||||
stack.push(function.value);
|
||||
for (_, arg) in args {
|
||||
stack.push(arg.value);
|
||||
}
|
||||
}
|
||||
RunLowLevel { args, .. } => {
|
||||
for (_, arg) in args {
|
||||
stack.push(arg);
|
||||
}
|
||||
}
|
||||
|
||||
Closure {
|
||||
arguments: args,
|
||||
loc_body: boxed_body,
|
||||
..
|
||||
} => {
|
||||
bound.extend(
|
||||
args.iter()
|
||||
.map(|t| symbols_from_pattern(&t.1.value).into_iter())
|
||||
.flatten(),
|
||||
);
|
||||
stack.push(boxed_body.value);
|
||||
}
|
||||
Record { fields, .. } => {
|
||||
for (_, field) in fields {
|
||||
stack.push(field.loc_expr.value);
|
||||
}
|
||||
}
|
||||
Update {
|
||||
symbol, updates, ..
|
||||
} => {
|
||||
used.insert(symbol);
|
||||
for (_, field) in updates {
|
||||
stack.push(field.loc_expr.value);
|
||||
}
|
||||
}
|
||||
Access { loc_expr, .. } => {
|
||||
stack.push(loc_expr.value);
|
||||
}
|
||||
|
||||
Accessor { .. } => {}
|
||||
|
||||
Tag { arguments, .. } => {
|
||||
for (_, arg) in arguments {
|
||||
stack.push(arg.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for b in bound {
|
||||
used.remove(&b);
|
||||
}
|
||||
|
||||
used
|
||||
}
|
||||
|
|
|
@ -12,6 +12,8 @@ roc_module = { path = "../module" }
|
|||
roc_types = { path = "../types" }
|
||||
roc_can = { path = "../can" }
|
||||
roc_unify = { path = "../unify" }
|
||||
roc_constrain = { path = "../constrain" }
|
||||
roc_solve = { path = "../solve" }
|
||||
roc_problem = { path = "../problem" }
|
||||
ven_pretty = { path = "../../vendor/pretty" }
|
||||
bumpalo = { version = "3.2", features = ["collections"] }
|
||||
|
|
306
compiler/mono/src/closures.rs
Normal file
306
compiler/mono/src/closures.rs
Normal file
|
@ -0,0 +1,306 @@
|
|||
use roc_can::constraint::Constraint;
|
||||
use roc_can::expected::Expected;
|
||||
use roc_can::expr::Expr;
|
||||
use roc_can::pattern::symbols_from_pattern;
|
||||
use roc_collections::all::MutSet;
|
||||
use roc_constrain::expr::exists;
|
||||
use roc_module::symbol::Symbol;
|
||||
use roc_region::all::Region;
|
||||
use roc_types::subs::{Subs, VarStore};
|
||||
use roc_types::types::{Category, Type};
|
||||
|
||||
pub fn infer_closure_size(expr: &Expr, mut subs: Subs) -> Subs {
|
||||
use roc_solve::solve;
|
||||
|
||||
let mut var_store = VarStore::new_from_subs(&mut subs);
|
||||
|
||||
let env = solve::Env::default();
|
||||
let mut problems = Vec::new();
|
||||
let constraint = generate_constraint(expr, &mut var_store);
|
||||
|
||||
let (solved_subs, _new_env) = solve::run(&env, &mut problems, subs, &constraint);
|
||||
|
||||
debug_assert_eq!(problems.len(), 0);
|
||||
|
||||
solved_subs.0
|
||||
}
|
||||
|
||||
pub fn free_variables(expr: &Expr) -> MutSet<Symbol> {
|
||||
use Expr::*;
|
||||
|
||||
let mut stack = vec![expr.clone()];
|
||||
|
||||
let mut bound = MutSet::default();
|
||||
let mut used: std::collections::HashSet<roc_module::symbol::Symbol, _> = MutSet::default();
|
||||
|
||||
while let Some(expr) = stack.pop() {
|
||||
match expr {
|
||||
Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {}
|
||||
|
||||
Var(s) => {
|
||||
used.insert(s);
|
||||
}
|
||||
List { loc_elems, .. } => {
|
||||
for e in loc_elems {
|
||||
stack.push(e.value);
|
||||
}
|
||||
}
|
||||
When {
|
||||
loc_cond, branches, ..
|
||||
} => {
|
||||
stack.push(loc_cond.value);
|
||||
|
||||
for branch in branches {
|
||||
stack.push(branch.value.value);
|
||||
|
||||
if let Some(guard) = branch.guard {
|
||||
stack.push(guard.value);
|
||||
}
|
||||
|
||||
bound.extend(
|
||||
branch
|
||||
.patterns
|
||||
.iter()
|
||||
.map(|t| symbols_from_pattern(&t.value).into_iter())
|
||||
.flatten(),
|
||||
);
|
||||
}
|
||||
}
|
||||
If {
|
||||
branches,
|
||||
final_else,
|
||||
..
|
||||
} => {
|
||||
for (cond, then) in branches {
|
||||
stack.push(cond.value);
|
||||
stack.push(then.value);
|
||||
}
|
||||
stack.push(final_else.value);
|
||||
}
|
||||
|
||||
LetRec(defs, cont, _, _) => {
|
||||
stack.push(cont.value);
|
||||
|
||||
for def in defs {
|
||||
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
|
||||
stack.push(def.loc_expr.value);
|
||||
}
|
||||
}
|
||||
LetNonRec(def, cont, _, _) => {
|
||||
stack.push(cont.value);
|
||||
|
||||
bound.extend(symbols_from_pattern(&def.loc_pattern.value));
|
||||
stack.push(def.loc_expr.value);
|
||||
}
|
||||
Call(boxed, args, _) => {
|
||||
let (_, function, _, _) = *boxed;
|
||||
|
||||
stack.push(function.value);
|
||||
for (_, arg) in args {
|
||||
stack.push(arg.value);
|
||||
}
|
||||
}
|
||||
RunLowLevel { args, .. } => {
|
||||
for (_, arg) in args {
|
||||
stack.push(arg);
|
||||
}
|
||||
}
|
||||
|
||||
Closure {
|
||||
arguments: args,
|
||||
loc_body: boxed_body,
|
||||
..
|
||||
} => {
|
||||
bound.extend(
|
||||
args.iter()
|
||||
.map(|t| symbols_from_pattern(&t.1.value).into_iter())
|
||||
.flatten(),
|
||||
);
|
||||
stack.push(boxed_body.value);
|
||||
}
|
||||
Record { fields, .. } => {
|
||||
for (_, field) in fields {
|
||||
stack.push(field.loc_expr.value);
|
||||
}
|
||||
}
|
||||
Update {
|
||||
symbol, updates, ..
|
||||
} => {
|
||||
used.insert(symbol);
|
||||
for (_, field) in updates {
|
||||
stack.push(field.loc_expr.value);
|
||||
}
|
||||
}
|
||||
Access { loc_expr, .. } => {
|
||||
stack.push(loc_expr.value);
|
||||
}
|
||||
|
||||
Accessor { .. } => {}
|
||||
|
||||
Tag { arguments, .. } => {
|
||||
for (_, arg) in arguments {
|
||||
stack.push(arg.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for b in bound {
|
||||
used.remove(&b);
|
||||
}
|
||||
|
||||
used
|
||||
}
|
||||
|
||||
pub fn generate_constraint(expr: &Expr, var_store: &mut VarStore) -> Constraint {
|
||||
let mut constraints = Vec::new();
|
||||
generate_constraints_help(expr, var_store, &mut constraints);
|
||||
Constraint::And(constraints)
|
||||
}
|
||||
|
||||
pub fn generate_constraints_help(
|
||||
expr: &Expr,
|
||||
var_store: &mut VarStore,
|
||||
constraints: &mut Vec<Constraint>,
|
||||
) {
|
||||
use Expr::*;
|
||||
|
||||
match expr {
|
||||
Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {}
|
||||
|
||||
Var(_) => {}
|
||||
List { loc_elems, .. } => {
|
||||
for e in loc_elems {
|
||||
generate_constraints_help(&e.value, var_store, constraints);
|
||||
}
|
||||
}
|
||||
When {
|
||||
loc_cond, branches, ..
|
||||
} => {
|
||||
generate_constraints_help(&loc_cond.value, var_store, constraints);
|
||||
|
||||
for branch in branches {
|
||||
generate_constraints_help(&branch.value.value, var_store, constraints);
|
||||
|
||||
if let Some(guard) = &branch.guard {
|
||||
generate_constraints_help(&guard.value, var_store, constraints);
|
||||
}
|
||||
}
|
||||
}
|
||||
If {
|
||||
branches,
|
||||
final_else,
|
||||
..
|
||||
} => {
|
||||
for (cond, then) in branches {
|
||||
generate_constraints_help(&cond.value, var_store, constraints);
|
||||
generate_constraints_help(&then.value, var_store, constraints);
|
||||
}
|
||||
generate_constraints_help(&final_else.value, var_store, constraints);
|
||||
}
|
||||
|
||||
LetRec(defs, cont, _, _) => {
|
||||
generate_constraints_help(&cont.value, var_store, constraints);
|
||||
|
||||
for def in defs {
|
||||
generate_constraints_help(&def.loc_expr.value, var_store, constraints);
|
||||
}
|
||||
}
|
||||
LetNonRec(def, cont, _, _) => {
|
||||
generate_constraints_help(&cont.value, var_store, constraints);
|
||||
|
||||
generate_constraints_help(&def.loc_expr.value, var_store, constraints);
|
||||
}
|
||||
Call(boxed, args, _) => {
|
||||
let (_, function, _, _) = &**boxed;
|
||||
|
||||
generate_constraints_help(&function.value, var_store, constraints);
|
||||
for (_, arg) in args {
|
||||
generate_constraints_help(&arg.value, var_store, constraints);
|
||||
}
|
||||
}
|
||||
RunLowLevel { args, .. } => {
|
||||
for (_, arg) in args {
|
||||
generate_constraints_help(&arg, var_store, constraints);
|
||||
}
|
||||
}
|
||||
|
||||
Closure {
|
||||
arguments: _,
|
||||
closure_type: closure_var,
|
||||
loc_body: boxed_body,
|
||||
..
|
||||
} => {
|
||||
let mut cons = Vec::new();
|
||||
let mut variables = Vec::new();
|
||||
|
||||
let closed_over_symbols = MutSet::default();
|
||||
let closure_ext_var = var_store.fresh();
|
||||
let closure_var = *closure_var;
|
||||
|
||||
variables.push(closure_ext_var);
|
||||
// TODO unsure about including this one
|
||||
variables.push(closure_var);
|
||||
|
||||
let mut tag_arguments = Vec::with_capacity(closed_over_symbols.len());
|
||||
for symbol in closed_over_symbols {
|
||||
let var = var_store.fresh();
|
||||
variables.push(var);
|
||||
tag_arguments.push(Type::Variable(var));
|
||||
|
||||
let region = Region::zero();
|
||||
let expected = Expected::NoExpectation(Type::Variable(var));
|
||||
let lookup = Constraint::Lookup(symbol, expected, region);
|
||||
cons.push(lookup);
|
||||
}
|
||||
|
||||
let tag_name_string = format!("Closure_{}", closure_var.index());
|
||||
let tag_name = roc_module::ident::TagName::Global(tag_name_string.into());
|
||||
let expected_type = Type::TagUnion(
|
||||
vec![(tag_name, tag_arguments)],
|
||||
Box::new(Type::Variable(closure_ext_var)),
|
||||
);
|
||||
|
||||
// constrain this closures's size to the type we just created
|
||||
let expected = Expected::NoExpectation(expected_type);
|
||||
let category = Category::ClosureSize;
|
||||
let region = boxed_body.region;
|
||||
let equality = Constraint::Eq(Type::Variable(closure_var), expected, category, region);
|
||||
|
||||
cons.push(equality);
|
||||
|
||||
// generate constraints for nested closures
|
||||
let mut inner_constraints = Vec::new();
|
||||
generate_constraints_help(&boxed_body.value, var_store, &mut inner_constraints);
|
||||
|
||||
cons.push(Constraint::And(inner_constraints));
|
||||
|
||||
let constraint = exists(variables, Constraint::And(cons));
|
||||
|
||||
constraints.push(constraint);
|
||||
}
|
||||
Record { fields, .. } => {
|
||||
for (_, field) in fields {
|
||||
generate_constraints_help(&field.loc_expr.value, var_store, constraints);
|
||||
}
|
||||
}
|
||||
Update {
|
||||
symbol: _, updates, ..
|
||||
} => {
|
||||
for (_, field) in updates {
|
||||
generate_constraints_help(&field.loc_expr.value, var_store, constraints);
|
||||
}
|
||||
}
|
||||
Access { loc_expr, .. } => {
|
||||
generate_constraints_help(&loc_expr.value, var_store, constraints);
|
||||
}
|
||||
|
||||
Accessor { .. } => {}
|
||||
|
||||
Tag { arguments, .. } => {
|
||||
for (_, arg) in arguments {
|
||||
generate_constraints_help(&arg.value, var_store, constraints);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@
|
|||
#![allow(clippy::large_enum_variant)]
|
||||
|
||||
pub mod borrow;
|
||||
pub mod closures;
|
||||
pub mod inc_dec;
|
||||
pub mod ir;
|
||||
pub mod layout;
|
||||
|
|
|
@ -882,6 +882,11 @@ fn add_category<'b>(
|
|||
|
||||
Lambda => alloc.concat(vec![this_is, alloc.text(" an anonymous function of type:")]),
|
||||
|
||||
ClosureSize => alloc.concat(vec![
|
||||
this_is,
|
||||
alloc.text(" the closure size of a function of type:"),
|
||||
]),
|
||||
|
||||
TagApply {
|
||||
tag_name: TagName::Global(name),
|
||||
args_count: 0,
|
||||
|
|
|
@ -70,7 +70,7 @@ pub enum TypeError {
|
|||
BadType(roc_types::types::Problem),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct Env {
|
||||
pub vars_by_symbol: SendMap<Symbol, Variable>,
|
||||
pub aliases: MutMap<Symbol, Alias>,
|
||||
|
|
|
@ -72,6 +72,13 @@ impl VarStore {
|
|||
VarStore { next: next_var.0 }
|
||||
}
|
||||
|
||||
pub fn new_from_subs(subs: &Subs) -> Self {
|
||||
let next_var = subs.utable.len() as u32;
|
||||
debug_assert!(next_var >= Variable::FIRST_USER_SPACE_VAR.0);
|
||||
|
||||
VarStore { next: next_var }
|
||||
}
|
||||
|
||||
pub fn fresh(&mut self) -> Variable {
|
||||
// Increment the counter and return the value it had before it was incremented.
|
||||
let answer = self.next;
|
||||
|
@ -163,6 +170,10 @@ impl Variable {
|
|||
pub unsafe fn unsafe_test_debug_variable(v: u32) -> Self {
|
||||
Variable(v)
|
||||
}
|
||||
|
||||
pub fn index(&self) -> u32 {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<OptVariable> for Variable {
|
||||
|
|
|
@ -942,6 +942,7 @@ pub enum Category {
|
|||
},
|
||||
Lambda,
|
||||
Uniqueness,
|
||||
ClosureSize,
|
||||
StrInterpolation,
|
||||
|
||||
// storing variables in the ast
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue