diff --git a/compiler/mono/src/expr.rs b/compiler/mono/src/expr.rs index e9d42fd544..cb201c4a66 100644 --- a/compiler/mono/src/expr.rs +++ b/compiler/mono/src/expr.rs @@ -335,7 +335,7 @@ fn pattern_to_when<'a>( (env.fresh_symbol(), body) } - Shadowed(_, _) | UnsupportedPattern(_) => { + Shadowed(_, _) | UnsupportedPattern(_) => { // create the runtime error here, instead of delegating to When. // UnsupportedPattern should then never occcur in When panic!("TODO generate runtime error here"); @@ -411,7 +411,7 @@ fn from_can<'a>( } // If it wasn't specifically an Identifier & Closure, proceed as normal. - let mono_pattern = from_can_pattern(env, loc_pattern.value); + let mono_pattern = from_can_pattern(env, &loc_pattern.value); store_pattern( env, mono_pattern, @@ -829,7 +829,7 @@ fn from_can_when<'a>( let mut stored = Vec::with_capacity_in(1, arena); let (loc_when_pattern, loc_branch) = branches.into_iter().next().unwrap(); - let mono_pattern = from_can_pattern(env, loc_when_pattern.value); + let mono_pattern = from_can_pattern(env, &loc_when_pattern.value); store_pattern( env, mono_pattern, @@ -844,7 +844,12 @@ fn from_can_when<'a>( Expr::Store(stored.into_bump_slice(), arena.alloc(ret)) } 2 => { - let loc_branches: std::vec::Vec<_> = branches.iter().map(|v| v.0.clone()).collect(); + let loc_branches: std::vec::Vec<_> = branches + .iter() + .map(|(loc_branch, _)| { + Located::at(loc_branch.region, from_can_pattern(env, &loc_branch.value)) + }) + .collect(); match crate::pattern::check(Region::zero(), &loc_branches) { Ok(_) => {} @@ -857,8 +862,8 @@ fn from_can_when<'a>( let (can_loc_when_pat1, loc_then) = iter.next().unwrap(); let (can_loc_when_pat2, loc_else) = iter.next().unwrap(); - let when_pat1 = from_can_pattern(env, can_loc_when_pat1.value); - let when_pat2 = from_can_pattern(env, can_loc_when_pat2.value); + let when_pat1 = from_can_pattern(env, &can_loc_when_pat1.value); + let when_pat2 = from_can_pattern(env, &can_loc_when_pat2.value); let cond_layout = Layout::Builtin(Builtin::Bool( TagName::Global("False".into()), @@ -926,7 +931,12 @@ fn from_can_when<'a>( } } _ => { - let loc_branches: std::vec::Vec<_> = branches.iter().map(|v| v.0.clone()).collect(); + let loc_branches: std::vec::Vec<_> = branches + .iter() + .map(|(loc_branch, _)| { + Located::at(loc_branch.region, from_can_pattern(env, &loc_branch.value)) + }) + .collect(); match crate::pattern::check(Region::zero(), &loc_branches) { Ok(_) => {} @@ -965,7 +975,7 @@ fn from_can_when<'a>( let mut is_last = true; for (loc_when_pat, loc_expr) in branches.into_iter().rev() { let mono_expr = from_can(env, loc_expr.value, procs, None); - let when_pat = from_can_pattern(env, loc_when_pat.value); + let when_pat = from_can_pattern(env, &loc_when_pat.value); if is_last { opt_default_branch = match &when_pat { @@ -976,7 +986,7 @@ fn from_can_when<'a>( arena.alloc(mono_expr.clone()), ))) } - Shadowed(_region, _ident) => { + Shadowed(_region, _ident) => { panic!("TODO make runtime exception out of the branch"); } _ => Some(arena.alloc(mono_expr.clone())), @@ -992,7 +1002,9 @@ fn from_can_when<'a>( jumpable_branches.push((*int as u64, mono_expr)); } BitLiteral(v) => jumpable_branches.push((*v as u64, mono_expr)), - EnumLiteral { tag_id , .. } => jumpable_branches.push((*tag_id as u64, mono_expr)), + EnumLiteral { tag_id, .. } => { + jumpable_branches.push((*tag_id as u64, mono_expr)) + } Identifier(_) => { // store is handled above } @@ -1198,9 +1210,13 @@ pub enum Pattern<'a> { tag_name: TagName, arguments: Vec<'a, Pattern<'a>>, layout: Layout<'a>, + union: crate::pattern::Union, }, BitLiteral(bool), - EnumLiteral {tag_id: u8, enum_size: u8 }, + EnumLiteral { + tag_id: u8, + enum_size: u8, + }, IntLiteral(i64), FloatLiteral(f64), StrLiteral(Box), @@ -1222,21 +1238,21 @@ pub struct RecordDestruct<'a> { fn from_can_pattern<'a>( env: &mut Env<'a, '_>, - can_pattern: roc_can::pattern::Pattern, + can_pattern: &roc_can::pattern::Pattern, ) -> Pattern<'a> { use roc_can::pattern::Pattern::*; match can_pattern { Underscore => Pattern::Underscore, - Identifier(symbol) => Pattern::Identifier(symbol), - IntLiteral(v) => Pattern::IntLiteral(v), - FloatLiteral(v) => Pattern::FloatLiteral(v), - StrLiteral(v) => Pattern::StrLiteral(v), - Shadowed(region, ident) => Pattern::Shadowed(region, ident), - UnsupportedPattern(region) => Pattern::UnsupportedPattern(region), + Identifier(symbol) => Pattern::Identifier(*symbol), + IntLiteral(v) => Pattern::IntLiteral(*v), + FloatLiteral(v) => Pattern::FloatLiteral(*v), + StrLiteral(v) => Pattern::StrLiteral(v.clone()), + Shadowed(region, ident) => Pattern::Shadowed(*region, ident.clone()), + UnsupportedPattern(region) => Pattern::UnsupportedPattern(*region), - NumLiteral(var, num) => match to_int_or_float(env.subs, var) { - IntOrFloat::IntType => Pattern::IntLiteral(num), - IntOrFloat::FloatType => Pattern::FloatLiteral(num as f64), + NumLiteral(var, num) => match to_int_or_float(env.subs, *var) { + IntOrFloat::IntType => Pattern::IntLiteral(*num), + IntOrFloat::FloatType => Pattern::FloatLiteral(*num as f64), }, AppliedTag { @@ -1244,23 +1260,49 @@ fn from_can_pattern<'a>( tag_name, arguments, .. - } => match Layout::from_var(env.arena, whole_var, env.subs, env.pointer_size) { + } => match Layout::from_var(env.arena, *whole_var, env.subs, env.pointer_size) { Ok(Layout::Builtin(Builtin::Bool(_bottom, top))) => { - Pattern::BitLiteral(tag_name == top) + Pattern::BitLiteral(tag_name == &top) } Ok(Layout::Builtin(Builtin::Byte(conversion))) => match conversion.get(&tag_name) { - Some(index) => Pattern::EnumLiteral{ tag_id : *index, enum_size: conversion.len() as u8 }, + Some(index) => Pattern::EnumLiteral { + tag_id: *index, + enum_size: conversion.len() as u8, + }, None => unreachable!("Tag must be in its own type"), }, Ok(layout) => { let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena); for (_, loc_pat) in arguments { - mono_args.push(from_can_pattern(env, loc_pat.value)); + mono_args.push(from_can_pattern(env, &loc_pat.value)); } + let mut fields = std::vec::Vec::new(); + let union = match roc_types::pretty_print::chase_ext_tag_union( + env.subs, + *whole_var, + &mut fields, + ) { + Ok(()) | Err((_, Content::FlexVar(_))) => { + let mut ctors = std::vec::Vec::with_capacity(fields.len()); + for (tag_name, args) in fields { + ctors.push(crate::pattern::Ctor { + name: tag_name.clone(), + arity: args.len(), + }) + } + + crate::pattern::Union { + alternatives: ctors, + } + } + Err(content) => panic!("invalid content in ext_var: {:?}", content), + }; + Pattern::AppliedTag { - tag_name, + tag_name: tag_name.clone(), arguments: mono_args, + union, layout, } } @@ -1271,11 +1313,11 @@ fn from_can_pattern<'a>( whole_var, destructs, .. - } => match Layout::from_var(env.arena, whole_var, env.subs, env.pointer_size) { + } => match Layout::from_var(env.arena, *whole_var, env.subs, env.pointer_size) { Ok(layout) => { let mut mono_destructs = Vec::with_capacity_in(destructs.len(), env.arena); for loc_rec_des in destructs { - mono_destructs.push(from_can_record_destruct(env, loc_rec_des.value)); + mono_destructs.push(from_can_record_destruct(env, &loc_rec_des.value)); } Pattern::RecordDestructure(mono_destructs, layout) @@ -1287,14 +1329,14 @@ fn from_can_pattern<'a>( fn from_can_record_destruct<'a>( env: &mut Env<'a, '_>, - can_rd: roc_can::pattern::RecordDestruct, + can_rd: &roc_can::pattern::RecordDestruct, ) -> RecordDestruct<'a> { RecordDestruct { - label: can_rd.label, + label: can_rd.label.clone(), symbol: can_rd.symbol, - guard: match can_rd.guard { + guard: match &can_rd.guard { None => None, - Some((_, loc_pattern)) => Some(from_can_pattern(env, loc_pattern.value)), + Some((_, loc_pattern)) => Some(from_can_pattern(env, &loc_pattern.value)), }, } } diff --git a/compiler/mono/src/pattern.rs b/compiler/mono/src/pattern.rs index 7068d1e725..c1b1dfb801 100644 --- a/compiler/mono/src/pattern.rs +++ b/compiler/mono/src/pattern.rs @@ -6,14 +6,13 @@ use self::Pattern::*; #[derive(Clone, Debug, PartialEq)] pub struct Union { - alternatives: Vec, - num_alts: usize, + pub alternatives: Vec, } #[derive(Clone, Debug, PartialEq)] pub struct Ctor { - name: TagName, - arity: usize, + pub name: TagName, + pub arity: usize, } #[derive(Clone, Debug, PartialEq)] @@ -25,27 +24,31 @@ pub enum Pattern { #[derive(Clone, Debug, PartialEq)] pub enum Literal { - Num(i64), Int(i64), + Bit(bool), + Byte(u8), Float(f64), Str(Box), } -fn simplify(pattern: &roc_can::pattern::Pattern) -> Pattern { +fn simplify<'a>(pattern: &crate::expr::Pattern<'a>) -> Pattern { let mut errors = Vec::new(); simplify_help(pattern, &mut errors) } -fn simplify_help(pattern: &roc_can::pattern::Pattern, errors: &mut Vec) -> Pattern { - use roc_can::pattern::Pattern::*; +fn simplify_help<'a>(pattern: &crate::expr::Pattern<'a>, errors: &mut Vec) -> Pattern { + use crate::expr::Pattern::*; match pattern { IntLiteral(v) => Literal(Literal::Int(*v)), - NumLiteral(_, v) => Literal(Literal::Int(*v)), FloatLiteral(v) => Literal(Literal::Float(*v)), StrLiteral(v) => Literal(Literal::Str(v.clone())), + // TODO make sure these are exhaustive + BitLiteral(b) => Literal(Literal::Bit(*b)), + EnumLiteral { tag_id, .. } => Literal(Literal::Byte(*tag_id)), + Underscore => Anything, Identifier(_) => Anything, RecordDestructure { .. } => { @@ -67,17 +70,14 @@ fn simplify_help(pattern: &roc_can::pattern::Pattern, errors: &mut Vec) - AppliedTag { tag_name, arguments, + union, .. } => { - let union = Union { - alternatives: Vec::new(), - num_alts: 0, - }; let simplified_args: std::vec::Vec<_> = arguments .iter() - .map(|v| simplify_help(&v.1.value, errors)) + .map(|v| simplify_help(&v, errors)) .collect(); - Ctor(union, tag_name.clone(), simplified_args) + Ctor(union.clone(), tag_name.clone(), simplified_args) } } } @@ -99,9 +99,9 @@ pub enum Context { /// Check -pub fn check( +pub fn check<'a>( region: Region, - patterns: &[Located], + patterns: &[Located>], ) -> Result<(), Vec> { let mut errors = Vec::new(); check_patterns(region, Context::BadArg, patterns, &mut errors); @@ -113,43 +113,10 @@ pub fn check( } } -// pub fn check(module: roc_can::module::ModuleOutput) -> Result<(), Vec> { -// let mut errors = Vec::new(); -// check_declarations(&module.declarations, &mut errors); -// -// if errors.is_empty() { -// Ok(()) -// } else { -// Err(errors) -// } -// } -// -// /// CHECK DECLS -// -// fn check_declarations(decls: &[roc_can::def::Declaration], errors: &mut Vec) { -// use roc_can::def::Declaration; -// -// for decl in decls { -// Declaration::Declare(def) => check_def(def, errors), -// Declaration::DeclareRef(defs) => { -// for def in defs { -// check_def(def, errors); -// } -// } -// Declaration::InvalidCycle(_) => {} -// } -// } -// -// fn check_def(def: &roc_can::def::Def, errors: &mut Vec) { -// check_patttern -// -// -// } - -pub fn check_patterns( +pub fn check_patterns<'a>( region: Region, context: Context, - patterns: &[Located], + patterns: &[Located>], errors: &mut Vec, ) { match to_nonredundant_rows(region, patterns) { @@ -197,7 +164,7 @@ fn is_exhaustive(matrix: &PatternMatrix, n: usize) -> PatternMatrix { let alts = ctors.iter().next().unwrap().1; let alt_list = &alts.alternatives; - let num_alts = alts.num_alts; + let num_alts = alt_list.len(); if num_seen < num_alts { let new_matrix = matrix @@ -277,9 +244,9 @@ fn recover_ctor( /// REDUNDANT PATTERNS /// INVARIANT: Produces a list of rows where (forall row. length row == 1) -fn to_nonredundant_rows( +fn to_nonredundant_rows<'a>( overall_region: Region, - patterns: &[Located], + patterns: &[Located>], ) -> Result>, Error> { let mut checked_rows = Vec::with_capacity(patterns.len()); @@ -449,12 +416,8 @@ fn is_complete(matrix: &PatternMatrix) -> Complete { match it.next() { None => Complete::No, - Some(Union { - alternatives, - num_alts, - .. - }) => { - if ctors.len() == *num_alts { + Some(Union { alternatives, .. }) => { + if ctors.len() == alternatives.len() { Complete::Yes(alternatives.to_vec()) } else { Complete::No @@ -473,7 +436,7 @@ fn collect_ctors(matrix: &RefPatternMatrix) -> MutMap { let mut ctors = MutMap::default(); for row in matrix { - if let Some(Ctor(union, name, _)) = row.get(0) { + if let Some(Ctor(union, name, _)) = row.get(row.len() - 1) { ctors.insert(name.clone(), union.clone()); } }