diff --git a/compiler/gen/tests/test_gen.rs b/compiler/gen/tests/test_gen.rs index f1c4a69c13..744bfdd189 100644 --- a/compiler/gen/tests/test_gen.rs +++ b/compiler/gen/tests/test_gen.rs @@ -991,6 +991,28 @@ mod test_gen { ); } + #[test] + fn when_on_enum() { + assert_evals_to!( + indoc!( + r#" + Fruit : [ Apple, Orange, Banana ] + + apple : Fruit + apple = Apple + + when apple is + Apple -> 1 + Banana -> 2 + Orange -> 3 + _ -> 4 + "# + ), + 1, + i64 + ); + } + // #[test] // fn basic_record() { // assert_evals_to!( diff --git a/compiler/mono/src/expr.rs b/compiler/mono/src/expr.rs index 28655d4a07..348e2896b1 100644 --- a/compiler/mono/src/expr.rs +++ b/compiler/mono/src/expr.rs @@ -2,11 +2,10 @@ use crate::layout::{Builtin, Layout}; use bumpalo::collections::Vec; use bumpalo::Bump; use roc_can; -use roc_can::pattern::Pattern; use roc_collections::all::{MutMap, MutSet}; -use roc_module::ident::{Lowercase, TagName}; +use roc_module::ident::{Ident, Lowercase, TagName}; use roc_module::symbol::{IdentIds, ModuleId, Symbol}; -use roc_region::all::Located; +use roc_region::all::{Located, Region}; use roc_types::subs::{Content, ContentHash, FlatType, Subs, Variable}; #[derive(Clone, Debug, PartialEq, Default)] @@ -335,7 +334,7 @@ fn pattern_to_when<'a>( (env.fresh_symbol(), body) } - AppliedTag(_, _, _) | RecordDestructure(_, _) | Shadowed(_, _) | UnsupportedPattern(_) => { + AppliedTag {..} | RecordDestructure {..} | Shadowed(_, _) | UnsupportedPattern(_) => { let symbol = env.fresh_symbol(); let wrapped_body = When { @@ -405,9 +404,10 @@ fn from_can<'a>( } // If it wasn't specifically an Identifier & Closure, proceed as normal. + let mono_pattern = from_can_pattern(env, loc_pattern.value); store_pattern( env, - loc_pattern.value, + mono_pattern, loc_expr.value, def.expr_var, procs, @@ -743,7 +743,7 @@ fn store_pattern<'a>( procs: &mut Procs<'a>, stored: &mut Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>, ) { - use roc_can::pattern::Pattern::*; + use Pattern::*; let layout = match Layout::from_var(env.arena, var, env.subs, env.pointer_size) { Ok(layout) => layout, @@ -792,7 +792,7 @@ fn from_can_when<'a>( )>, procs: &mut Procs<'a>, ) -> Expr<'a> { - use roc_can::pattern::Pattern::*; + use Pattern::*; match branches.len() { 0 => { @@ -807,9 +807,10 @@ fn from_can_when<'a>( let mut stored = Vec::with_capacity_in(1, arena); let (loc_when_pattern, loc_branch) = branches.into_iter().next().unwrap(); + let mono_pattern = from_can_pattern(env, loc_when_pattern.value); store_pattern( env, - loc_when_pattern.value, + mono_pattern, loc_cond.value, cond_var, procs, @@ -824,52 +825,18 @@ fn from_can_when<'a>( // A when-expression with exactly 2 branches compiles to a Cond. let arena = env.arena; let mut iter = branches.into_iter(); - let (loc_when_pat1, loc_then) = iter.next().unwrap(); - let (loc_when_pat2, loc_else) = iter.next().unwrap(); + let (can_loc_when_pat1, loc_then) = iter.next().unwrap(); + let (can_loc_when_pat2, loc_else) = iter.next().unwrap(); + + let when_pat1 = from_can_pattern(env, can_loc_when_pat1.value); + let when_pat2 = from_can_pattern(env, can_loc_when_pat2.value); let cond_layout = Layout::Builtin(Builtin::Bool( TagName::Global("False".into()), TagName::Global("True".into()), )); - match (&loc_when_pat1.value, &loc_when_pat2.value) { - (NumLiteral(var, num), Underscore) => { - let cond_lhs = from_can(env, loc_cond.value, procs, None); - - let (fn_symbol, builtin, cond_rhs_expr) = match to_int_or_float(env.subs, *var) - { - IntOrFloat::IntType => { - (Symbol::INT_EQ_I64, Builtin::Int64, Expr::Int(*num)) - } - IntOrFloat::FloatType => { - (Symbol::FLOAT_EQ, Builtin::Float64, Expr::Float(*num as f64)) - } - }; - let cond_rhs = cond_rhs_expr; - - let cond = arena.alloc(Expr::CallByName( - fn_symbol, - arena.alloc([ - (cond_lhs, Layout::Builtin(builtin.clone())), - (cond_rhs, Layout::Builtin(builtin)), - ]), - )); - - let pass = arena.alloc(from_can(env, loc_then.value, procs, None)); - let fail = arena.alloc(from_can(env, loc_else.value, procs, None)); - let ret_layout = Layout::from_var(arena, expr_var, env.subs, env.pointer_size) - .unwrap_or_else(|err| { - panic!("TODO turn this into a RuntimeError {:?}", err) - }); - - Expr::Cond { - cond_layout, - cond, - pass, - fail, - ret_layout, - } - } + match (&when_pat1, &when_pat2) { (IntLiteral(int), Underscore) => { let cond_lhs = from_can(env, loc_cond.value, procs, None); let cond_rhs = Expr::Int(*int); @@ -946,6 +913,8 @@ fn from_can_when<'a>( // TODO we can also convert floats to integer representations. let is_switchable = match layout { Layout::Builtin(Builtin::Int64) => true, + Layout::Builtin(Builtin::Bool(_, _)) => true, + Layout::Builtin(Builtin::Byte(_)) => true, _ => false, }; @@ -959,43 +928,31 @@ fn from_can_when<'a>( for (loc_when_pat, loc_expr) in branches { let mono_expr = from_can(env, loc_expr.value, procs, None); + let when_pat = from_can_pattern(env, loc_when_pat.value); - match &loc_when_pat.value { - NumLiteral(var, num) => { - // This is jumpable iff it's an int - match to_int_or_float(env.subs, *var) { - IntOrFloat::IntType => { - jumpable_branches.push((*num as u64, mono_expr)); - } - IntOrFloat::FloatType => { - // The type checker should have converted these mismatches into RuntimeErrors already! - if cfg!(debug_assertions) { - panic!("A type mismatch in a pattern was not converted to a runtime error: {:?}", loc_when_pat); - } else { - unreachable!(); - } - } - }; - } + match &when_pat { IntLiteral(int) => { // Switch only compares the condition to the // alternatives based on their bit patterns, // so casting from i64 to u64 makes no difference here. jumpable_branches.push((*int as u64, mono_expr)); } - Identifier(_symbol) => { + BitLiteral(v) => jumpable_branches.push((*v as u64, mono_expr)), + EnumLiteral(v) => jumpable_branches.push((*v as u64, mono_expr)), + Identifier(symbol) => { // Since this is an ident, it must be // the last pattern in the `when`. // We can safely treat this like an `_` // except that we need to wrap this branch // in a `Store` so the identifier is in scope! - opt_default_branch = Some(arena.alloc(if true { - // Using `if true` for this TODO panic to avoid a warning - panic!("TODO wrap this expr in an Expr::Store: {:?}", mono_expr) - } else { - mono_expr - })); + // TODO does this evaluate `cond` twice? + let mono_with_store = Expr::Store( + arena.alloc([(*symbol, layout.clone(), cond.clone())]), + arena.alloc(mono_expr), + ); + + opt_default_branch = Some(arena.alloc(mono_with_store)); } Underscore => { // We should always have exactly one default branch! @@ -1016,7 +973,7 @@ fn from_can_when<'a>( | FloatLiteral(_) => { // The type checker should have converted these mismatches into RuntimeErrors already! if cfg!(debug_assertions) { - panic!("A type mismatch in a pattern was not converted to a runtime error: {:?}", loc_when_pat); + panic!("A type mismatch in a pattern was not converted to a runtime error: {:?}", when_pat); } else { unreachable!(); } @@ -1035,6 +992,7 @@ fn from_can_when<'a>( .unwrap_or_else(|err| { panic!("TODO turn cond_layout into a RuntimeError {:?}", err) }); + let ret_layout = Layout::from_var(arena, expr_var, env.subs, env.pointer_size) .unwrap_or_else(|err| { panic!("TODO turn ret_layout into a RuntimeError {:?}", err) @@ -1196,3 +1154,105 @@ fn specialize_proc_body<'a>( Some(proc) } + +/// A pattern, including possible problems (e.g. shadowing) so that +/// codegen can generate a runtime error if this pattern is reached. +#[derive(Clone, Debug, PartialEq)] +pub enum Pattern<'a> { + Identifier(Symbol), + AppliedTag(TagName, Vec<'a, Pattern<'a>>, Layout<'a>), + BitLiteral(bool), + EnumLiteral(u8), + IntLiteral(i64), + FloatLiteral(f64), + StrLiteral(Box), + RecordDestructure(Vec<'a, RecordDestruct<'a>>, Layout<'a>), + Underscore, + + // Runtime Exceptions + Shadowed(Region, Located), + // Example: (5 = 1 + 2) is an unsupported pattern in an assignment; Int patterns aren't allowed in assignments! + UnsupportedPattern(Region), +} + +#[derive(Clone, Debug, PartialEq)] +pub struct RecordDestruct<'a> { + pub label: Lowercase, + pub symbol: Symbol, + pub guard: Option>, +} + +fn from_can_pattern<'a>( + env: &mut Env<'a, '_>, + can_pattern: roc_can::pattern::Pattern, +) -> Pattern<'a> { + use roc_can::pattern::Pattern::*; + match can_pattern { + Underscore => Pattern::Underscore, + Identifier(symbol) => Pattern::Identifier(symbol), + IntLiteral(v) => Pattern::IntLiteral(v), + FloatLiteral(v) => Pattern::FloatLiteral(v), + StrLiteral(v) => Pattern::StrLiteral(v), + Shadowed(region, ident) => Pattern::Shadowed(region, ident), + UnsupportedPattern(region) => Pattern::UnsupportedPattern(region), + + NumLiteral(var, num) => match to_int_or_float(env.subs, var) { + IntOrFloat::IntType => Pattern::IntLiteral(num), + IntOrFloat::FloatType => Pattern::FloatLiteral(num as f64), + }, + + AppliedTag { + whole_var, + tag_name, + arguments, + .. + } => match Layout::from_var(env.arena, whole_var, env.subs, env.pointer_size) { + Ok(Layout::Builtin(Builtin::Bool(_bottom, top))) => { + Pattern::BitLiteral(tag_name == top) + } + Ok(Layout::Builtin(Builtin::Byte(conversion))) => match conversion.get(&tag_name) { + Some(index) => Pattern::EnumLiteral(*index), + None => unreachable!("Tag must be in its own type"), + }, + Ok(layout) => { + let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena); + for (_, loc_pat) in arguments { + mono_args.push(from_can_pattern(env, loc_pat.value)); + } + + Pattern::AppliedTag(tag_name, mono_args, layout) + } + Err(()) => panic!("Invalid layout"), + }, + + RecordDestructure { + whole_var, + destructs, + .. + } => match Layout::from_var(env.arena, whole_var, env.subs, env.pointer_size) { + Ok(layout) => { + let mut mono_destructs = Vec::with_capacity_in(destructs.len(), env.arena); + for loc_rec_des in destructs { + mono_destructs.push(from_can_record_destruct(env, loc_rec_des.value)); + } + + Pattern::RecordDestructure(mono_destructs, layout) + } + Err(()) => panic!("Invalid layout"), + }, + } +} + +fn from_can_record_destruct<'a>( + env: &mut Env<'a, '_>, + can_rd: roc_can::pattern::RecordDestruct, +) -> RecordDestruct<'a> { + RecordDestruct { + label: can_rd.label, + symbol: can_rd.symbol, + guard: match can_rd.guard { + None => None, + Some((_, loc_pattern)) => Some(from_can_pattern(env, loc_pattern.value)), + }, + } +}