Specialize if

This commit is contained in:
Agus Zubiaga 2024-12-11 18:09:54 -03:00
parent 2ca829aaa8
commit 0585f32039
No known key found for this signature in database
7 changed files with 186 additions and 11 deletions

View file

@ -194,6 +194,38 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
MonoExpr::Struct(slice)
}
Expr::If {
cond_var: _,
branch_var,
branches,
final_else,
} => {
let branch_type = mono_from_var(*branch_var);
let mono_final_else = self.to_mono_expr(&final_else.value);
let final_else = self.mono_exprs.add(mono_final_else, final_else.region);
let mut branch_pairs: Vec<((MonoExpr, Region), (MonoExpr, Region))> =
Vec::with_capacity_in(branches.len(), self.arena);
for (cond, body) in branches {
let mono_cond = self.to_mono_expr(&cond.value);
let mono_body = self.to_mono_expr(&body.value);
branch_pairs.push(((mono_cond, cond.region), (mono_body, body.region)));
}
let branches = self.mono_exprs.extend_pairs(branch_pairs.into_iter());
MonoExpr::If {
branch_type,
branches,
final_else,
}
}
Expr::Var(symbol, var) | Expr::ParamsVar { symbol, var, .. } => {
MonoExpr::Lookup(*symbol, mono_from_var(*var))
}
// Expr::Call((fn_var, fn_expr, capture_var, ret_var), args, called_via) => {
// let opt_ret_type = mono_from_var(*var);
@ -258,7 +290,6 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
// })
// }
// }
Expr::Var(symbol, var) => MonoExpr::Lookup(*symbol, mono_from_var(*var)),
// Expr::LetNonRec(def, loc) => {
// let expr = self.to_mono_expr(def.loc_expr.value, stmts)?;
// let todo = (); // TODO if this is an underscore pattern and we're doing a fn call, convert it to Stmt::CallVoid
@ -298,12 +329,6 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
// branches_cond_var,
// exhaustive,
// } => todo!(),
// Expr::If {
// cond_var,
// branch_var,
// branches,
// final_else,
// } => todo!(),
// Expr::Call(_, vec, called_via) => todo!(),
// Expr::RunLowLevel { op, args, ret_var } => todo!(),
// Expr::ForeignCall {

View file

@ -6,7 +6,7 @@ use roc_can::expr::Recursive;
use roc_module::low_level::LowLevel;
use roc_module::symbol::Symbol;
use roc_region::all::Region;
use soa::{Index, NonEmptySlice, Slice, Slice2, Slice3};
use soa::{Index, NonEmptySlice, PairSlice, Slice, Slice2, Slice3};
use std::iter;
#[derive(Clone, Copy, Debug, PartialEq)]
@ -146,6 +146,50 @@ impl MonoExprs {
Slice::new(start as u32, len as u16)
}
pub fn iter_pair_slice(
&self,
exprs: PairSlice<MonoExpr>,
) -> impl Iterator<Item = (&MonoExpr, &MonoExpr)> {
exprs.indices_iter().map(|(index_a, index_b)| {
debug_assert!(
self.exprs.len() > index_a && self.exprs.len() > index_b,
"A Slice index was not found in MonoExprs. This should never happen!"
);
// Safety: we should only ever hand out MonoExprId slices that are valid indices into here.
unsafe {
(
self.exprs.get_unchecked(index_a),
self.exprs.get_unchecked(index_b),
)
}
})
}
pub fn extend_pairs(
&mut self,
exprs: impl Iterator<Item = ((MonoExpr, Region), (MonoExpr, Region))>,
) -> PairSlice<MonoExpr> {
let start = self.exprs.len();
let additional = exprs.size_hint().0 * 2;
self.exprs.reserve(additional);
self.regions.reserve(additional);
let mut pairs = 0;
for ((expr_a, region_a), (expr_b, region_b)) in exprs {
self.exprs.push(expr_a);
self.exprs.push(expr_b);
self.regions.push(region_a);
self.regions.push(region_b);
pairs += 1;
}
PairSlice::new(start as u32, pairs)
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@ -329,6 +373,12 @@ pub enum MonoExpr {
final_expr: MonoExprId,
},
If {
branch_type: MonoTypeId,
branches: PairSlice<MonoExpr>,
final_else: MonoExprId,
},
CompilerBug(Problem),
}

View file

@ -167,8 +167,38 @@ fn dbg_mono_expr_help<'a>(
MonoExpr::Unit => {
write!(buf, "{{}}").unwrap();
}
MonoExpr::If {
branch_type: _,
branches,
final_else,
} => {
write!(buf, "If(",).unwrap();
for (index, (cond, branch)) in mono_exprs.iter_pair_slice(*branches).enumerate() {
if index > 0 {
write!(buf, ", ").unwrap();
}
dbg_mono_expr_help(arena, mono_exprs, interns, cond, buf);
write!(buf, " -> ").unwrap();
dbg_mono_expr_help(arena, mono_exprs, interns, branch, buf);
write!(buf, ")").unwrap();
}
write!(buf, ", ").unwrap();
dbg_mono_expr_help(
arena,
mono_exprs,
interns,
mono_exprs.get_expr(*final_else),
buf,
);
write!(buf, ")").unwrap();
}
MonoExpr::Lookup(ident, _mono_type_id) => {
write!(buf, "{:?}", ident).unwrap();
}
// MonoExpr::List { elem_type, elems } => todo!(),
// MonoExpr::Lookup(symbol, mono_type_id) => todo!(),
// MonoExpr::ParameterizedLookup {
// name,
// lookup_type,

View file

@ -0,0 +1,36 @@
#[macro_use]
extern crate pretty_assertions;
#[cfg(test)]
mod helpers;
#[cfg(test)]
mod specialize_structs {
use crate::helpers::expect_mono_expr_str;
#[test]
fn single_branch() {
let cond = "Bool.true";
let then = 42;
let else_ = 0;
expect_mono_expr_str(
format!("if {cond} then {then} else {else_}"),
format!("If(`Bool.true` -> Number(I8(42))), Number(I8(0)))"),
);
}
#[test]
fn multiple_branches() {
let cond1 = "Bool.false";
let then1 = 256;
let cond2 = "Bool.true";
let then2 = 24;
let then_else = 0;
expect_mono_expr_str(
format!("if {cond1} then {then1} else if {cond2} then {then2} else {then_else}"),
format!("If(`Bool.false` -> Number(I16(256))), `Bool.true` -> Number(I16(24))), Number(I16(0)))"),
);
}
}

View file

@ -7,7 +7,7 @@ mod helpers;
#[cfg(test)]
mod specialize_primitives {
use roc_module::symbol::Symbol;
use roc_specialize_types::{MonoExpr, MonoType, MonoTypeId, Number, Primitive};
use roc_specialize_types::{MonoExpr, MonoTypeId, Number};
use super::helpers::{expect_mono_expr, expect_mono_expr_with_interns};

View file

@ -8,6 +8,6 @@ mod soa_slice3;
pub use either_index::*;
pub use soa_index::*;
pub use soa_slice::{NonEmptySlice, Slice};
pub use soa_slice::{NonEmptySlice, PairSlice, Slice};
pub use soa_slice2::Slice2;
pub use soa_slice3::Slice3;

View file

@ -274,3 +274,37 @@ impl<T> IntoIterator for NonEmptySlice<T> {
self.inner.into_iter()
}
}
/// Like `Slice`, but for pairs of `T`
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Copy, Clone)]
pub struct PairSlice<T>(Slice<T>);
impl<T> PairSlice<T> {
pub const fn start(self) -> u32 {
self.0.start()
}
pub const fn empty() -> Self {
Self(Slice::empty())
}
pub const fn len(&self) -> usize {
self.0.len() / 2
}
pub fn indices_iter(&self) -> impl Iterator<Item = (usize, usize)> {
(self.0.start as usize..(self.0.start as usize + self.0.length as usize))
.step_by(2)
.map(|i| (i, i + 1))
}
pub const fn new(start: u32, length: u16) -> Self {
Self(Slice::new(start, length * 2))
}
}
impl<T> Default for PairSlice<T> {
fn default() -> Self {
Self(Slice::default())
}
}