From 05d1f28e8358e3fd2d38c770de5a8a53f94cf9fb Mon Sep 17 00:00:00 2001 From: Folkert Date: Fri, 2 Oct 2020 20:46:47 +0200 Subject: [PATCH] add logic to generate/solve closure size constraints --- compiler/can/src/expr.rs | 129 ++++++++++- compiler/mono/Cargo.toml | 2 + compiler/mono/src/closures.rs | 306 +++++++++++++++++++++++++++ compiler/mono/src/lib.rs | 1 + compiler/reporting/src/error/type.rs | 5 + compiler/solve/src/solve.rs | 2 +- compiler/types/src/subs.rs | 11 + compiler/types/src/types.rs | 1 + 8 files changed, 455 insertions(+), 2 deletions(-) create mode 100644 compiler/mono/src/closures.rs diff --git a/compiler/can/src/expr.rs b/compiler/can/src/expr.rs index aef77617ba..b4e73bebe7 100644 --- a/compiler/can/src/expr.rs +++ b/compiler/can/src/expr.rs @@ -6,7 +6,7 @@ use crate::num::{ finish_parsing_base, finish_parsing_float, finish_parsing_int, float_expr_from_result, int_expr_from_result, num_expr_from_result, }; -use crate::pattern::{canonicalize_pattern, Pattern}; +use crate::pattern::{canonicalize_pattern, symbols_from_pattern, Pattern}; use crate::procedure::References; use crate::scope::Scope; use inlinable_string::InlinableString; @@ -1525,3 +1525,130 @@ pub fn unescape_char(escaped: &EscapedChar) -> char { Newline => '\n', } } + +pub fn free_variables(expr: &Expr) -> MutSet { + use Expr::*; + + let mut stack = vec![expr.clone()]; + + let mut bound = MutSet::default(); + let mut used: std::collections::HashSet = MutSet::default(); + + while let Some(expr) = stack.pop() { + match expr { + Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {} + + Var(s) => { + used.insert(s); + } + List { loc_elems, .. } => { + for e in loc_elems { + stack.push(e.value); + } + } + When { + loc_cond, branches, .. + } => { + stack.push(loc_cond.value); + + for branch in branches { + stack.push(branch.value.value); + + if let Some(guard) = branch.guard { + stack.push(guard.value); + } + + bound.extend( + branch + .patterns + .iter() + .map(|t| symbols_from_pattern(&t.value).into_iter()) + .flatten(), + ); + } + } + If { + branches, + final_else, + .. + } => { + for (cond, then) in branches { + stack.push(cond.value); + stack.push(then.value); + } + stack.push(final_else.value); + } + + LetRec(defs, cont, _, _) => { + stack.push(cont.value); + + for def in defs { + bound.extend(symbols_from_pattern(&def.loc_pattern.value)); + stack.push(def.loc_expr.value); + } + } + LetNonRec(def, cont, _, _) => { + stack.push(cont.value); + + bound.extend(symbols_from_pattern(&def.loc_pattern.value)); + stack.push(def.loc_expr.value); + } + Call(boxed, args, _) => { + let (_, function, _, _) = *boxed; + + stack.push(function.value); + for (_, arg) in args { + stack.push(arg.value); + } + } + RunLowLevel { args, .. } => { + for (_, arg) in args { + stack.push(arg); + } + } + + Closure { + arguments: args, + loc_body: boxed_body, + .. + } => { + bound.extend( + args.iter() + .map(|t| symbols_from_pattern(&t.1.value).into_iter()) + .flatten(), + ); + stack.push(boxed_body.value); + } + Record { fields, .. } => { + for (_, field) in fields { + stack.push(field.loc_expr.value); + } + } + Update { + symbol, updates, .. + } => { + used.insert(symbol); + for (_, field) in updates { + stack.push(field.loc_expr.value); + } + } + Access { loc_expr, .. } => { + stack.push(loc_expr.value); + } + + Accessor { .. } => {} + + Tag { arguments, .. } => { + for (_, arg) in arguments { + stack.push(arg.value); + } + } + } + } + + for b in bound { + used.remove(&b); + } + + used +} diff --git a/compiler/mono/Cargo.toml b/compiler/mono/Cargo.toml index 5301d0bdc3..079e0f0da0 100644 --- a/compiler/mono/Cargo.toml +++ b/compiler/mono/Cargo.toml @@ -12,6 +12,8 @@ roc_module = { path = "../module" } roc_types = { path = "../types" } roc_can = { path = "../can" } roc_unify = { path = "../unify" } +roc_constrain = { path = "../constrain" } +roc_solve = { path = "../solve" } roc_problem = { path = "../problem" } ven_pretty = { path = "../../vendor/pretty" } bumpalo = { version = "3.2", features = ["collections"] } diff --git a/compiler/mono/src/closures.rs b/compiler/mono/src/closures.rs new file mode 100644 index 0000000000..139e2f9c2f --- /dev/null +++ b/compiler/mono/src/closures.rs @@ -0,0 +1,306 @@ +use roc_can::constraint::Constraint; +use roc_can::expected::Expected; +use roc_can::expr::Expr; +use roc_can::pattern::symbols_from_pattern; +use roc_collections::all::MutSet; +use roc_constrain::expr::exists; +use roc_module::symbol::Symbol; +use roc_region::all::Region; +use roc_types::subs::{Subs, VarStore}; +use roc_types::types::{Category, Type}; + +pub fn infer_closure_size(expr: &Expr, mut subs: Subs) -> Subs { + use roc_solve::solve; + + let mut var_store = VarStore::new_from_subs(&mut subs); + + let env = solve::Env::default(); + let mut problems = Vec::new(); + let constraint = generate_constraint(expr, &mut var_store); + + let (solved_subs, _new_env) = solve::run(&env, &mut problems, subs, &constraint); + + debug_assert_eq!(problems.len(), 0); + + solved_subs.0 +} + +pub fn free_variables(expr: &Expr) -> MutSet { + use Expr::*; + + let mut stack = vec![expr.clone()]; + + let mut bound = MutSet::default(); + let mut used: std::collections::HashSet = MutSet::default(); + + while let Some(expr) = stack.pop() { + match expr { + Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {} + + Var(s) => { + used.insert(s); + } + List { loc_elems, .. } => { + for e in loc_elems { + stack.push(e.value); + } + } + When { + loc_cond, branches, .. + } => { + stack.push(loc_cond.value); + + for branch in branches { + stack.push(branch.value.value); + + if let Some(guard) = branch.guard { + stack.push(guard.value); + } + + bound.extend( + branch + .patterns + .iter() + .map(|t| symbols_from_pattern(&t.value).into_iter()) + .flatten(), + ); + } + } + If { + branches, + final_else, + .. + } => { + for (cond, then) in branches { + stack.push(cond.value); + stack.push(then.value); + } + stack.push(final_else.value); + } + + LetRec(defs, cont, _, _) => { + stack.push(cont.value); + + for def in defs { + bound.extend(symbols_from_pattern(&def.loc_pattern.value)); + stack.push(def.loc_expr.value); + } + } + LetNonRec(def, cont, _, _) => { + stack.push(cont.value); + + bound.extend(symbols_from_pattern(&def.loc_pattern.value)); + stack.push(def.loc_expr.value); + } + Call(boxed, args, _) => { + let (_, function, _, _) = *boxed; + + stack.push(function.value); + for (_, arg) in args { + stack.push(arg.value); + } + } + RunLowLevel { args, .. } => { + for (_, arg) in args { + stack.push(arg); + } + } + + Closure { + arguments: args, + loc_body: boxed_body, + .. + } => { + bound.extend( + args.iter() + .map(|t| symbols_from_pattern(&t.1.value).into_iter()) + .flatten(), + ); + stack.push(boxed_body.value); + } + Record { fields, .. } => { + for (_, field) in fields { + stack.push(field.loc_expr.value); + } + } + Update { + symbol, updates, .. + } => { + used.insert(symbol); + for (_, field) in updates { + stack.push(field.loc_expr.value); + } + } + Access { loc_expr, .. } => { + stack.push(loc_expr.value); + } + + Accessor { .. } => {} + + Tag { arguments, .. } => { + for (_, arg) in arguments { + stack.push(arg.value); + } + } + } + } + + for b in bound { + used.remove(&b); + } + + used +} + +pub fn generate_constraint(expr: &Expr, var_store: &mut VarStore) -> Constraint { + let mut constraints = Vec::new(); + generate_constraints_help(expr, var_store, &mut constraints); + Constraint::And(constraints) +} + +pub fn generate_constraints_help( + expr: &Expr, + var_store: &mut VarStore, + constraints: &mut Vec, +) { + use Expr::*; + + match expr { + Num(_, _) | Int(_, _) | Float(_, _) | Str(_) | EmptyRecord | RuntimeError(_) => {} + + Var(_) => {} + List { loc_elems, .. } => { + for e in loc_elems { + generate_constraints_help(&e.value, var_store, constraints); + } + } + When { + loc_cond, branches, .. + } => { + generate_constraints_help(&loc_cond.value, var_store, constraints); + + for branch in branches { + generate_constraints_help(&branch.value.value, var_store, constraints); + + if let Some(guard) = &branch.guard { + generate_constraints_help(&guard.value, var_store, constraints); + } + } + } + If { + branches, + final_else, + .. + } => { + for (cond, then) in branches { + generate_constraints_help(&cond.value, var_store, constraints); + generate_constraints_help(&then.value, var_store, constraints); + } + generate_constraints_help(&final_else.value, var_store, constraints); + } + + LetRec(defs, cont, _, _) => { + generate_constraints_help(&cont.value, var_store, constraints); + + for def in defs { + generate_constraints_help(&def.loc_expr.value, var_store, constraints); + } + } + LetNonRec(def, cont, _, _) => { + generate_constraints_help(&cont.value, var_store, constraints); + + generate_constraints_help(&def.loc_expr.value, var_store, constraints); + } + Call(boxed, args, _) => { + let (_, function, _, _) = &**boxed; + + generate_constraints_help(&function.value, var_store, constraints); + for (_, arg) in args { + generate_constraints_help(&arg.value, var_store, constraints); + } + } + RunLowLevel { args, .. } => { + for (_, arg) in args { + generate_constraints_help(&arg, var_store, constraints); + } + } + + Closure { + arguments: _, + closure_type: closure_var, + loc_body: boxed_body, + .. + } => { + let mut cons = Vec::new(); + let mut variables = Vec::new(); + + let closed_over_symbols = MutSet::default(); + let closure_ext_var = var_store.fresh(); + let closure_var = *closure_var; + + variables.push(closure_ext_var); + // TODO unsure about including this one + variables.push(closure_var); + + let mut tag_arguments = Vec::with_capacity(closed_over_symbols.len()); + for symbol in closed_over_symbols { + let var = var_store.fresh(); + variables.push(var); + tag_arguments.push(Type::Variable(var)); + + let region = Region::zero(); + let expected = Expected::NoExpectation(Type::Variable(var)); + let lookup = Constraint::Lookup(symbol, expected, region); + cons.push(lookup); + } + + let tag_name_string = format!("Closure_{}", closure_var.index()); + let tag_name = roc_module::ident::TagName::Global(tag_name_string.into()); + let expected_type = Type::TagUnion( + vec![(tag_name, tag_arguments)], + Box::new(Type::Variable(closure_ext_var)), + ); + + // constrain this closures's size to the type we just created + let expected = Expected::NoExpectation(expected_type); + let category = Category::ClosureSize; + let region = boxed_body.region; + let equality = Constraint::Eq(Type::Variable(closure_var), expected, category, region); + + cons.push(equality); + + // generate constraints for nested closures + let mut inner_constraints = Vec::new(); + generate_constraints_help(&boxed_body.value, var_store, &mut inner_constraints); + + cons.push(Constraint::And(inner_constraints)); + + let constraint = exists(variables, Constraint::And(cons)); + + constraints.push(constraint); + } + Record { fields, .. } => { + for (_, field) in fields { + generate_constraints_help(&field.loc_expr.value, var_store, constraints); + } + } + Update { + symbol: _, updates, .. + } => { + for (_, field) in updates { + generate_constraints_help(&field.loc_expr.value, var_store, constraints); + } + } + Access { loc_expr, .. } => { + generate_constraints_help(&loc_expr.value, var_store, constraints); + } + + Accessor { .. } => {} + + Tag { arguments, .. } => { + for (_, arg) in arguments { + generate_constraints_help(&arg.value, var_store, constraints); + } + } + } +} diff --git a/compiler/mono/src/lib.rs b/compiler/mono/src/lib.rs index e44b3d7fae..e4876e4547 100644 --- a/compiler/mono/src/lib.rs +++ b/compiler/mono/src/lib.rs @@ -12,6 +12,7 @@ #![allow(clippy::large_enum_variant)] pub mod borrow; +pub mod closures; pub mod inc_dec; pub mod ir; pub mod layout; diff --git a/compiler/reporting/src/error/type.rs b/compiler/reporting/src/error/type.rs index 798cb38e2c..dccd69e372 100644 --- a/compiler/reporting/src/error/type.rs +++ b/compiler/reporting/src/error/type.rs @@ -882,6 +882,11 @@ fn add_category<'b>( Lambda => alloc.concat(vec![this_is, alloc.text(" an anonymous function of type:")]), + ClosureSize => alloc.concat(vec![ + this_is, + alloc.text(" the closure size of a function of type:"), + ]), + TagApply { tag_name: TagName::Global(name), args_count: 0, diff --git a/compiler/solve/src/solve.rs b/compiler/solve/src/solve.rs index 506fd6bee4..3c9cc0841a 100644 --- a/compiler/solve/src/solve.rs +++ b/compiler/solve/src/solve.rs @@ -70,7 +70,7 @@ pub enum TypeError { BadType(roc_types::types::Problem), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct Env { pub vars_by_symbol: SendMap, pub aliases: MutMap, diff --git a/compiler/types/src/subs.rs b/compiler/types/src/subs.rs index e57fe32f9e..7ab457a423 100644 --- a/compiler/types/src/subs.rs +++ b/compiler/types/src/subs.rs @@ -72,6 +72,13 @@ impl VarStore { VarStore { next: next_var.0 } } + pub fn new_from_subs(subs: &Subs) -> Self { + let next_var = subs.utable.len() as u32; + debug_assert!(next_var >= Variable::FIRST_USER_SPACE_VAR.0); + + VarStore { next: next_var } + } + pub fn fresh(&mut self) -> Variable { // Increment the counter and return the value it had before it was incremented. let answer = self.next; @@ -163,6 +170,10 @@ impl Variable { pub unsafe fn unsafe_test_debug_variable(v: u32) -> Self { Variable(v) } + + pub fn index(&self) -> u32 { + self.0 + } } impl Into for Variable { diff --git a/compiler/types/src/types.rs b/compiler/types/src/types.rs index 5fb19fe21b..07ba293f05 100644 --- a/compiler/types/src/types.rs +++ b/compiler/types/src/types.rs @@ -942,6 +942,7 @@ pub enum Category { }, Lambda, Uniqueness, + ClosureSize, StrInterpolation, // storing variables in the ast