diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index 436eb43446..16b3ea0a84 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -1,6 +1,7 @@ use crate::exhaustive::{Ctor, RenderAs, TagId, Union}; use crate::ir::{ - BranchInfo, DestructType, Env, Expr, JoinPointId, Literal, Param, Pattern, Procs, Stmt, + BranchInfo, DestructType, Env, Expr, FloatPrecision, IntPrecision, JoinPointId, Literal, Param, + Pattern, Procs, Stmt, }; use crate::layout::{Builtin, Layout, LayoutCache, UnionLayout}; use roc_collections::all::{MutMap, MutSet}; @@ -85,8 +86,8 @@ enum Test<'a> { union: crate::exhaustive::Union, arguments: Vec<(Pattern<'a>, Layout<'a>)>, }, - IsInt(i128), - IsFloat(u64), + IsInt(i128, IntPrecision), + IsFloat(u64, FloatPrecision), IsDecimal(RocDec), IsStr(Box), IsBit(bool), @@ -95,6 +96,7 @@ enum Test<'a> { num_alts: usize, }, } + use std::hash::{Hash, Hasher}; impl<'a> Hash for Test<'a> { fn hash(&self, state: &mut H) { @@ -106,13 +108,15 @@ impl<'a> Hash for Test<'a> { tag_id.hash(state); // The point of this custom implementation is to not hash the tag arguments } - IsInt(v) => { + IsInt(v, width) => { state.write_u8(1); v.hash(state); + width.hash(state); } - IsFloat(v) => { + IsFloat(v, width) => { state.write_u8(2); v.hash(state); + width.hash(state); } IsStr(v) => { state.write_u8(3); @@ -306,8 +310,8 @@ fn tests_are_complete_help(last_test: &Test, number_of_tests: usize) -> bool { Test::IsCtor { union, .. } => number_of_tests == union.alternatives.len(), Test::IsByte { num_alts, .. } => number_of_tests == *num_alts, Test::IsBit(_) => number_of_tests == 2, - Test::IsInt(_) => false, - Test::IsFloat(_) => false, + Test::IsInt(_, _) => false, + Test::IsFloat(_, _) => false, Test::IsDecimal(_) => false, Test::IsStr(_) => false, } @@ -561,8 +565,8 @@ fn test_at_path<'a>( tag_id: *tag_id, num_alts: union.alternatives.len(), }, - IntLiteral(v) => IsInt(*v), - FloatLiteral(v) => IsFloat(*v), + IntLiteral(v, precision) => IsInt(*v, *precision), + FloatLiteral(v, precision) => IsFloat(*v, *precision), DecimalLiteral(v) => IsDecimal(*v), StrLiteral(v) => IsStr(v.clone()), }; @@ -807,8 +811,9 @@ fn to_relevant_branch_help<'a>( _ => None, }, - IntLiteral(int) => match test { - IsInt(is_int) if int == *is_int => { + IntLiteral(int, p1) => match test { + IsInt(is_int, p2) if int == *is_int => { + debug_assert_eq!(p1, *p2); start.extend(end); Some(Branch { goal: branch.goal, @@ -819,8 +824,9 @@ fn to_relevant_branch_help<'a>( _ => None, }, - FloatLiteral(float) => match test { - IsFloat(test_float) if float == *test_float => { + FloatLiteral(float, p1) => match test { + IsFloat(test_float, p2) if float == *test_float => { + debug_assert_eq!(p1, *p2); start.extend(end); Some(Branch { goal: branch.goal, @@ -928,8 +934,8 @@ fn needs_tests(pattern: &Pattern) -> bool { | AppliedTag { .. } | BitLiteral { .. } | EnumLiteral { .. } - | IntLiteral(_) - | FloatLiteral(_) + | IntLiteral(_, _) + | FloatLiteral(_, _) | DecimalLiteral(_) | StrLiteral(_) => true, } @@ -1280,22 +1286,22 @@ fn test_to_equality<'a>( _ => unreachable!("{:?}", (cond_layout, union)), } } - Test::IsInt(test_int) => { + Test::IsInt(test_int, precision) => { // TODO don't downcast i128 here debug_assert!(test_int <= i64::MAX as i128); let lhs = Expr::Literal(Literal::Int(test_int as i128)); let lhs_symbol = env.unique_symbol(); - stores.push((lhs_symbol, Layout::Builtin(Builtin::Int64), lhs)); + stores.push((lhs_symbol, precision.as_layout(), lhs)); (stores, lhs_symbol, rhs_symbol, None) } - Test::IsFloat(test_int) => { + Test::IsFloat(test_int, precision) => { // TODO maybe we can actually use i64 comparison here? let test_float = f64::from_bits(test_int as u64); let lhs = Expr::Literal(Literal::Float(test_float)); let lhs_symbol = env.unique_symbol(); - stores.push((lhs_symbol, Layout::Builtin(Builtin::Float64), lhs)); + stores.push((lhs_symbol, precision.as_layout(), lhs)); (stores, lhs_symbol, rhs_symbol, None) } @@ -1303,7 +1309,7 @@ fn test_to_equality<'a>( Test::IsDecimal(test_dec) => { let lhs = Expr::Literal(Literal::Int(test_dec.0)); let lhs_symbol = env.unique_symbol(); - stores.push((lhs_symbol, Layout::Builtin(Builtin::Int128), lhs)); + stores.push((lhs_symbol, *cond_layout, lhs)); (stores, lhs_symbol, rhs_symbol, None) } @@ -1737,8 +1743,8 @@ fn decide_to_branching<'a>( ); let tag = match test { - Test::IsInt(v) => v as u64, - Test::IsFloat(v) => v as u64, + Test::IsInt(v, _) => v as u64, + Test::IsFloat(v, _) => v as u64, Test::IsBit(v) => v as u64, Test::IsByte { tag_id, .. } => tag_id as u64, Test::IsCtor { tag_id, .. } => tag_id as u64, diff --git a/compiler/mono/src/exhaustive.rs b/compiler/mono/src/exhaustive.rs index 51b20d32aa..77c28d64a1 100644 --- a/compiler/mono/src/exhaustive.rs +++ b/compiler/mono/src/exhaustive.rs @@ -65,8 +65,8 @@ fn simplify(pattern: &crate::ir::Pattern) -> Pattern { use crate::ir::Pattern::*; match pattern { - IntLiteral(v) => Literal(Literal::Int(*v)), - FloatLiteral(v) => Literal(Literal::Float(*v)), + IntLiteral(v, _) => Literal(Literal::Int(*v)), + FloatLiteral(v, _) => Literal(Literal::Float(*v)), DecimalLiteral(v) => Literal(Literal::Decimal(*v)), StrLiteral(v) => Literal(Literal::Str(v.clone())), diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index aaf54ebb5c..6cee53ac32 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -2774,13 +2774,13 @@ pub fn with_hole<'a>( IntOrFloat::SignedIntType(precision) => Stmt::Let( assigned, Expr::Literal(Literal::Int(int)), - Layout::Builtin(int_precision_to_builtin(precision)), + precision.as_layout(), hole, ), IntOrFloat::UnsignedIntType(precision) => Stmt::Let( assigned, Expr::Literal(Literal::Int(int)), - Layout::Builtin(int_precision_to_builtin(precision)), + precision.as_layout(), hole, ), _ => unreachable!("unexpected float precision for integer"), @@ -2792,7 +2792,7 @@ pub fn with_hole<'a>( IntOrFloat::BinaryFloatType(precision) => Stmt::Let( assigned, Expr::Literal(Literal::Float(float)), - Layout::Builtin(float_precision_to_builtin(precision)), + precision.as_layout(), hole, ), IntOrFloat::DecimalFloatType => { @@ -2824,19 +2824,19 @@ pub fn with_hole<'a>( IntOrFloat::SignedIntType(precision) => Stmt::Let( assigned, Expr::Literal(Literal::Int(num.into())), - Layout::Builtin(int_precision_to_builtin(precision)), + precision.as_layout(), hole, ), IntOrFloat::UnsignedIntType(precision) => Stmt::Let( assigned, Expr::Literal(Literal::Int(num.into())), - Layout::Builtin(int_precision_to_builtin(precision)), + precision.as_layout(), hole, ), IntOrFloat::BinaryFloatType(precision) => Stmt::Let( assigned, Expr::Literal(Literal::Float(num as f64)), - Layout::Builtin(float_precision_to_builtin(precision)), + precision.as_layout(), hole, ), IntOrFloat::DecimalFloatType => { @@ -5634,8 +5634,8 @@ fn store_pattern_help<'a>( // do nothing return StorePattern::NotProductive(stmt); } - IntLiteral(_) - | FloatLiteral(_) + IntLiteral(_, _) + | FloatLiteral(_, _) | DecimalLiteral(_) | EnumLiteral { .. } | BitLiteral { .. } @@ -5769,8 +5769,8 @@ fn store_tag_pattern<'a>( Underscore => { // ignore } - IntLiteral(_) - | FloatLiteral(_) + IntLiteral(_, _) + | FloatLiteral(_, _) | DecimalLiteral(_) | EnumLiteral { .. } | BitLiteral { .. } @@ -5845,8 +5845,8 @@ fn store_newtype_pattern<'a>( Underscore => { // ignore } - IntLiteral(_) - | FloatLiteral(_) + IntLiteral(_, _) + | FloatLiteral(_, _) | DecimalLiteral(_) | EnumLiteral { .. } | BitLiteral { .. } @@ -5921,8 +5921,8 @@ fn store_record_destruct<'a>( // internally. But `y` is never used, so we must make sure it't not stored/loaded. return StorePattern::NotProductive(stmt); } - IntLiteral(_) - | FloatLiteral(_) + IntLiteral(_, _) + | FloatLiteral(_, _) | DecimalLiteral(_) | EnumLiteral { .. } | BitLiteral { .. } @@ -6892,8 +6892,8 @@ fn call_specialized_proc<'a>( pub enum Pattern<'a> { Identifier(Symbol), Underscore, - IntLiteral(i128), - FloatLiteral(u64), + IntLiteral(i128, IntPrecision), + FloatLiteral(u64, FloatPrecision), DecimalLiteral(RocDec), BitLiteral { value: bool, @@ -6971,7 +6971,22 @@ fn from_can_pattern_help<'a>( match can_pattern { Underscore => Ok(Pattern::Underscore), Identifier(symbol) => Ok(Pattern::Identifier(*symbol)), - IntLiteral(_, _, int) => Ok(Pattern::IntLiteral(*int as i128)), + IntLiteral(var, _, int) => { + let precision = { + match num_argument_to_int_or_float(env.subs, env.ptr_bytes, *var, false) { + IntOrFloat::SignedIntType(precision) + | IntOrFloat::UnsignedIntType(precision) => precision, + other => { + panic!( + "Invalid precision for int pattern: {:?} has {:?}", + can_pattern, other + ) + } + } + }; + + Ok(Pattern::IntLiteral(*int as i128, precision)) + } FloatLiteral(var, float_str, float) => { // TODO: Can I reuse num_argument_to_int_or_float here if I pass in true? match num_argument_to_int_or_float(env.subs, env.ptr_bytes, *var, true) { @@ -6981,7 +6996,9 @@ fn from_can_pattern_help<'a>( IntOrFloat::UnsignedIntType(_) => { panic!("Invalid percision for float literal = {:?}", var) } - IntOrFloat::BinaryFloatType(_) => Ok(Pattern::FloatLiteral(f64::to_bits(*float))), + IntOrFloat::BinaryFloatType(precision) => { + Ok(Pattern::FloatLiteral(f64::to_bits(*float), precision)) + } IntOrFloat::DecimalFloatType => { let dec = match RocDec::from_str(float_str) { Some(d) => d, @@ -7003,9 +7020,15 @@ fn from_can_pattern_help<'a>( } NumLiteral(var, num_str, num) => { match num_argument_to_int_or_float(env.subs, env.ptr_bytes, *var, false) { - IntOrFloat::SignedIntType(_) => Ok(Pattern::IntLiteral(*num as i128)), - IntOrFloat::UnsignedIntType(_) => Ok(Pattern::IntLiteral(*num as i128)), - IntOrFloat::BinaryFloatType(_) => Ok(Pattern::FloatLiteral(*num as u64)), + IntOrFloat::SignedIntType(precision) => { + Ok(Pattern::IntLiteral(*num as i128, precision)) + } + IntOrFloat::UnsignedIntType(precision) => { + Ok(Pattern::IntLiteral(*num as i128, precision)) + } + IntOrFloat::BinaryFloatType(precision) => { + Ok(Pattern::FloatLiteral(*num as u64, precision)) + } IntOrFloat::DecimalFloatType => { let dec = match RocDec::from_str(num_str) { Some(d) => d, @@ -7587,7 +7610,7 @@ fn from_can_record_destruct<'a>( }) } -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq, Hash)] pub enum IntPrecision { Usize, I128, @@ -7597,11 +7620,45 @@ pub enum IntPrecision { I8, } +impl IntPrecision { + pub fn as_layout(&self) -> Layout<'static> { + Layout::Builtin(self.as_builtin()) + } + + pub fn as_builtin(&self) -> Builtin<'static> { + use IntPrecision::*; + match self { + I128 => Builtin::Int128, + I64 => Builtin::Int64, + I32 => Builtin::Int32, + I16 => Builtin::Int16, + I8 => Builtin::Int8, + Usize => Builtin::Usize, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Hash)] pub enum FloatPrecision { F64, F32, } +impl FloatPrecision { + pub fn as_layout(&self) -> Layout<'static> { + Layout::Builtin(self.as_builtin()) + } + + pub fn as_builtin(&self) -> Builtin<'static> { + use FloatPrecision::*; + match self { + F64 => Builtin::Float64, + F32 => Builtin::Float32, + } + } +} + +#[derive(Debug)] pub enum IntOrFloat { SignedIntType(IntPrecision), UnsignedIntType(IntPrecision), @@ -7609,26 +7666,6 @@ pub enum IntOrFloat { DecimalFloatType, } -fn float_precision_to_builtin(precision: FloatPrecision) -> Builtin<'static> { - use FloatPrecision::*; - match precision { - F64 => Builtin::Float64, - F32 => Builtin::Float32, - } -} - -fn int_precision_to_builtin(precision: IntPrecision) -> Builtin<'static> { - use IntPrecision::*; - match precision { - I128 => Builtin::Int128, - I64 => Builtin::Int64, - I32 => Builtin::Int32, - I16 => Builtin::Int16, - I8 => Builtin::Int8, - Usize => Builtin::Usize, - } -} - /// Given the `a` in `Num a`, determines whether it's an int or a float pub fn num_argument_to_int_or_float( subs: &Subs, diff --git a/compiler/test_gen/src/gen_num.rs b/compiler/test_gen/src/gen_num.rs index 29bb479855..9d9d9e1390 100644 --- a/compiler/test_gen/src/gen_num.rs +++ b/compiler/test_gen/src/gen_num.rs @@ -1778,4 +1778,48 @@ mod gen_num { u32 ); } + + #[test] + fn when_on_i32() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + x : I32 + x = 0 + + main : I32 + main = + when x is + 0 -> 42 + _ -> -1 + "# + ), + 42, + i32 + ); + } + + #[test] + fn when_on_i16() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + x : I16 + x = 0 + + main : I16 + main = + when x is + 0 -> 42 + _ -> -1 + "# + ), + 42, + i16 + ); + } }