mirror of
https://github.com/roc-lang/roc.git
synced 2025-08-04 12:18:19 +00:00
Specialize if
This commit is contained in:
parent
2ca829aaa8
commit
0585f32039
7 changed files with 186 additions and 11 deletions
|
@ -194,6 +194,38 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
|
|||
|
||||
MonoExpr::Struct(slice)
|
||||
}
|
||||
Expr::If {
|
||||
cond_var: _,
|
||||
branch_var,
|
||||
branches,
|
||||
final_else,
|
||||
} => {
|
||||
let branch_type = mono_from_var(*branch_var);
|
||||
|
||||
let mono_final_else = self.to_mono_expr(&final_else.value);
|
||||
let final_else = self.mono_exprs.add(mono_final_else, final_else.region);
|
||||
|
||||
let mut branch_pairs: Vec<((MonoExpr, Region), (MonoExpr, Region))> =
|
||||
Vec::with_capacity_in(branches.len(), self.arena);
|
||||
|
||||
for (cond, body) in branches {
|
||||
let mono_cond = self.to_mono_expr(&cond.value);
|
||||
let mono_body = self.to_mono_expr(&body.value);
|
||||
|
||||
branch_pairs.push(((mono_cond, cond.region), (mono_body, body.region)));
|
||||
}
|
||||
|
||||
let branches = self.mono_exprs.extend_pairs(branch_pairs.into_iter());
|
||||
|
||||
MonoExpr::If {
|
||||
branch_type,
|
||||
branches,
|
||||
final_else,
|
||||
}
|
||||
}
|
||||
Expr::Var(symbol, var) | Expr::ParamsVar { symbol, var, .. } => {
|
||||
MonoExpr::Lookup(*symbol, mono_from_var(*var))
|
||||
}
|
||||
// Expr::Call((fn_var, fn_expr, capture_var, ret_var), args, called_via) => {
|
||||
// let opt_ret_type = mono_from_var(*var);
|
||||
|
||||
|
@ -258,7 +290,6 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
|
|||
// })
|
||||
// }
|
||||
// }
|
||||
Expr::Var(symbol, var) => MonoExpr::Lookup(*symbol, mono_from_var(*var)),
|
||||
// Expr::LetNonRec(def, loc) => {
|
||||
// let expr = self.to_mono_expr(def.loc_expr.value, stmts)?;
|
||||
// let todo = (); // TODO if this is an underscore pattern and we're doing a fn call, convert it to Stmt::CallVoid
|
||||
|
@ -298,12 +329,6 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
|
|||
// branches_cond_var,
|
||||
// exhaustive,
|
||||
// } => todo!(),
|
||||
// Expr::If {
|
||||
// cond_var,
|
||||
// branch_var,
|
||||
// branches,
|
||||
// final_else,
|
||||
// } => todo!(),
|
||||
// Expr::Call(_, vec, called_via) => todo!(),
|
||||
// Expr::RunLowLevel { op, args, ret_var } => todo!(),
|
||||
// Expr::ForeignCall {
|
||||
|
|
|
@ -6,7 +6,7 @@ use roc_can::expr::Recursive;
|
|||
use roc_module::low_level::LowLevel;
|
||||
use roc_module::symbol::Symbol;
|
||||
use roc_region::all::Region;
|
||||
use soa::{Index, NonEmptySlice, Slice, Slice2, Slice3};
|
||||
use soa::{Index, NonEmptySlice, PairSlice, Slice, Slice2, Slice3};
|
||||
use std::iter;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
|
@ -146,6 +146,50 @@ impl MonoExprs {
|
|||
|
||||
Slice::new(start as u32, len as u16)
|
||||
}
|
||||
|
||||
pub fn iter_pair_slice(
|
||||
&self,
|
||||
exprs: PairSlice<MonoExpr>,
|
||||
) -> impl Iterator<Item = (&MonoExpr, &MonoExpr)> {
|
||||
exprs.indices_iter().map(|(index_a, index_b)| {
|
||||
debug_assert!(
|
||||
self.exprs.len() > index_a && self.exprs.len() > index_b,
|
||||
"A Slice index was not found in MonoExprs. This should never happen!"
|
||||
);
|
||||
|
||||
// Safety: we should only ever hand out MonoExprId slices that are valid indices into here.
|
||||
unsafe {
|
||||
(
|
||||
self.exprs.get_unchecked(index_a),
|
||||
self.exprs.get_unchecked(index_b),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn extend_pairs(
|
||||
&mut self,
|
||||
exprs: impl Iterator<Item = ((MonoExpr, Region), (MonoExpr, Region))>,
|
||||
) -> PairSlice<MonoExpr> {
|
||||
let start = self.exprs.len();
|
||||
|
||||
let additional = exprs.size_hint().0 * 2;
|
||||
self.exprs.reserve(additional);
|
||||
self.regions.reserve(additional);
|
||||
|
||||
let mut pairs = 0;
|
||||
|
||||
for ((expr_a, region_a), (expr_b, region_b)) in exprs {
|
||||
self.exprs.push(expr_a);
|
||||
self.exprs.push(expr_b);
|
||||
self.regions.push(region_a);
|
||||
self.regions.push(region_b);
|
||||
|
||||
pairs += 1;
|
||||
}
|
||||
|
||||
PairSlice::new(start as u32, pairs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
|
@ -329,6 +373,12 @@ pub enum MonoExpr {
|
|||
final_expr: MonoExprId,
|
||||
},
|
||||
|
||||
If {
|
||||
branch_type: MonoTypeId,
|
||||
branches: PairSlice<MonoExpr>,
|
||||
final_else: MonoExprId,
|
||||
},
|
||||
|
||||
CompilerBug(Problem),
|
||||
}
|
||||
|
||||
|
|
|
@ -167,8 +167,38 @@ fn dbg_mono_expr_help<'a>(
|
|||
MonoExpr::Unit => {
|
||||
write!(buf, "{{}}").unwrap();
|
||||
}
|
||||
MonoExpr::If {
|
||||
branch_type: _,
|
||||
branches,
|
||||
final_else,
|
||||
} => {
|
||||
write!(buf, "If(",).unwrap();
|
||||
|
||||
for (index, (cond, branch)) in mono_exprs.iter_pair_slice(*branches).enumerate() {
|
||||
if index > 0 {
|
||||
write!(buf, ", ").unwrap();
|
||||
}
|
||||
|
||||
dbg_mono_expr_help(arena, mono_exprs, interns, cond, buf);
|
||||
write!(buf, " -> ").unwrap();
|
||||
dbg_mono_expr_help(arena, mono_exprs, interns, branch, buf);
|
||||
write!(buf, ")").unwrap();
|
||||
}
|
||||
|
||||
write!(buf, ", ").unwrap();
|
||||
dbg_mono_expr_help(
|
||||
arena,
|
||||
mono_exprs,
|
||||
interns,
|
||||
mono_exprs.get_expr(*final_else),
|
||||
buf,
|
||||
);
|
||||
write!(buf, ")").unwrap();
|
||||
}
|
||||
MonoExpr::Lookup(ident, _mono_type_id) => {
|
||||
write!(buf, "{:?}", ident).unwrap();
|
||||
}
|
||||
// MonoExpr::List { elem_type, elems } => todo!(),
|
||||
// MonoExpr::Lookup(symbol, mono_type_id) => todo!(),
|
||||
// MonoExpr::ParameterizedLookup {
|
||||
// name,
|
||||
// lookup_type,
|
||||
|
|
36
crates/compiler/specialize_types/tests/specialize_if.rs
Normal file
36
crates/compiler/specialize_types/tests/specialize_if.rs
Normal file
|
@ -0,0 +1,36 @@
|
|||
#[macro_use]
|
||||
extern crate pretty_assertions;
|
||||
|
||||
#[cfg(test)]
|
||||
mod helpers;
|
||||
|
||||
#[cfg(test)]
|
||||
mod specialize_structs {
|
||||
use crate::helpers::expect_mono_expr_str;
|
||||
|
||||
#[test]
|
||||
fn single_branch() {
|
||||
let cond = "Bool.true";
|
||||
let then = 42;
|
||||
let else_ = 0;
|
||||
|
||||
expect_mono_expr_str(
|
||||
format!("if {cond} then {then} else {else_}"),
|
||||
format!("If(`Bool.true` -> Number(I8(42))), Number(I8(0)))"),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_branches() {
|
||||
let cond1 = "Bool.false";
|
||||
let then1 = 256;
|
||||
let cond2 = "Bool.true";
|
||||
let then2 = 24;
|
||||
let then_else = 0;
|
||||
|
||||
expect_mono_expr_str(
|
||||
format!("if {cond1} then {then1} else if {cond2} then {then2} else {then_else}"),
|
||||
format!("If(`Bool.false` -> Number(I16(256))), `Bool.true` -> Number(I16(24))), Number(I16(0)))"),
|
||||
);
|
||||
}
|
||||
}
|
|
@ -7,7 +7,7 @@ mod helpers;
|
|||
#[cfg(test)]
|
||||
mod specialize_primitives {
|
||||
use roc_module::symbol::Symbol;
|
||||
use roc_specialize_types::{MonoExpr, MonoType, MonoTypeId, Number, Primitive};
|
||||
use roc_specialize_types::{MonoExpr, MonoTypeId, Number};
|
||||
|
||||
use super::helpers::{expect_mono_expr, expect_mono_expr_with_interns};
|
||||
|
||||
|
|
|
@ -8,6 +8,6 @@ mod soa_slice3;
|
|||
|
||||
pub use either_index::*;
|
||||
pub use soa_index::*;
|
||||
pub use soa_slice::{NonEmptySlice, Slice};
|
||||
pub use soa_slice::{NonEmptySlice, PairSlice, Slice};
|
||||
pub use soa_slice2::Slice2;
|
||||
pub use soa_slice3::Slice3;
|
||||
|
|
|
@ -274,3 +274,37 @@ impl<T> IntoIterator for NonEmptySlice<T> {
|
|||
self.inner.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
/// Like `Slice`, but for pairs of `T`
|
||||
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Copy, Clone)]
|
||||
pub struct PairSlice<T>(Slice<T>);
|
||||
|
||||
impl<T> PairSlice<T> {
|
||||
pub const fn start(self) -> u32 {
|
||||
self.0.start()
|
||||
}
|
||||
|
||||
pub const fn empty() -> Self {
|
||||
Self(Slice::empty())
|
||||
}
|
||||
|
||||
pub const fn len(&self) -> usize {
|
||||
self.0.len() / 2
|
||||
}
|
||||
|
||||
pub fn indices_iter(&self) -> impl Iterator<Item = (usize, usize)> {
|
||||
(self.0.start as usize..(self.0.start as usize + self.0.length as usize))
|
||||
.step_by(2)
|
||||
.map(|i| (i, i + 1))
|
||||
}
|
||||
|
||||
pub const fn new(start: u32, length: u16) -> Self {
|
||||
Self(Slice::new(start, length * 2))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Default for PairSlice<T> {
|
||||
fn default() -> Self {
|
||||
Self(Slice::default())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue