respect int/float precision in pattern matchs

This commit is contained in:
Folkert 2021-09-18 22:55:34 +02:00
parent cde9f97415
commit ada331567a
4 changed files with 153 additions and 66 deletions

View file

@ -1,6 +1,7 @@
use crate::exhaustive::{Ctor, RenderAs, TagId, Union}; use crate::exhaustive::{Ctor, RenderAs, TagId, Union};
use crate::ir::{ use crate::ir::{
BranchInfo, DestructType, Env, Expr, JoinPointId, Literal, Param, Pattern, Procs, Stmt, BranchInfo, DestructType, Env, Expr, FloatPrecision, IntPrecision, JoinPointId, Literal, Param,
Pattern, Procs, Stmt,
}; };
use crate::layout::{Builtin, Layout, LayoutCache, UnionLayout}; use crate::layout::{Builtin, Layout, LayoutCache, UnionLayout};
use roc_collections::all::{MutMap, MutSet}; use roc_collections::all::{MutMap, MutSet};
@ -85,8 +86,8 @@ enum Test<'a> {
union: crate::exhaustive::Union, union: crate::exhaustive::Union,
arguments: Vec<(Pattern<'a>, Layout<'a>)>, arguments: Vec<(Pattern<'a>, Layout<'a>)>,
}, },
IsInt(i128), IsInt(i128, IntPrecision),
IsFloat(u64), IsFloat(u64, FloatPrecision),
IsDecimal(RocDec), IsDecimal(RocDec),
IsStr(Box<str>), IsStr(Box<str>),
IsBit(bool), IsBit(bool),
@ -95,6 +96,7 @@ enum Test<'a> {
num_alts: usize, num_alts: usize,
}, },
} }
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
impl<'a> Hash for Test<'a> { impl<'a> Hash for Test<'a> {
fn hash<H: Hasher>(&self, state: &mut H) { fn hash<H: Hasher>(&self, state: &mut H) {
@ -106,13 +108,15 @@ impl<'a> Hash for Test<'a> {
tag_id.hash(state); tag_id.hash(state);
// The point of this custom implementation is to not hash the tag arguments // The point of this custom implementation is to not hash the tag arguments
} }
IsInt(v) => { IsInt(v, width) => {
state.write_u8(1); state.write_u8(1);
v.hash(state); v.hash(state);
width.hash(state);
} }
IsFloat(v) => { IsFloat(v, width) => {
state.write_u8(2); state.write_u8(2);
v.hash(state); v.hash(state);
width.hash(state);
} }
IsStr(v) => { IsStr(v) => {
state.write_u8(3); state.write_u8(3);
@ -306,8 +310,8 @@ fn tests_are_complete_help(last_test: &Test, number_of_tests: usize) -> bool {
Test::IsCtor { union, .. } => number_of_tests == union.alternatives.len(), Test::IsCtor { union, .. } => number_of_tests == union.alternatives.len(),
Test::IsByte { num_alts, .. } => number_of_tests == *num_alts, Test::IsByte { num_alts, .. } => number_of_tests == *num_alts,
Test::IsBit(_) => number_of_tests == 2, Test::IsBit(_) => number_of_tests == 2,
Test::IsInt(_) => false, Test::IsInt(_, _) => false,
Test::IsFloat(_) => false, Test::IsFloat(_, _) => false,
Test::IsDecimal(_) => false, Test::IsDecimal(_) => false,
Test::IsStr(_) => false, Test::IsStr(_) => false,
} }
@ -561,8 +565,8 @@ fn test_at_path<'a>(
tag_id: *tag_id, tag_id: *tag_id,
num_alts: union.alternatives.len(), num_alts: union.alternatives.len(),
}, },
IntLiteral(v) => IsInt(*v), IntLiteral(v, precision) => IsInt(*v, *precision),
FloatLiteral(v) => IsFloat(*v), FloatLiteral(v, precision) => IsFloat(*v, *precision),
DecimalLiteral(v) => IsDecimal(*v), DecimalLiteral(v) => IsDecimal(*v),
StrLiteral(v) => IsStr(v.clone()), StrLiteral(v) => IsStr(v.clone()),
}; };
@ -807,8 +811,9 @@ fn to_relevant_branch_help<'a>(
_ => None, _ => None,
}, },
IntLiteral(int) => match test { IntLiteral(int, p1) => match test {
IsInt(is_int) if int == *is_int => { IsInt(is_int, p2) if int == *is_int => {
debug_assert_eq!(p1, *p2);
start.extend(end); start.extend(end);
Some(Branch { Some(Branch {
goal: branch.goal, goal: branch.goal,
@ -819,8 +824,9 @@ fn to_relevant_branch_help<'a>(
_ => None, _ => None,
}, },
FloatLiteral(float) => match test { FloatLiteral(float, p1) => match test {
IsFloat(test_float) if float == *test_float => { IsFloat(test_float, p2) if float == *test_float => {
debug_assert_eq!(p1, *p2);
start.extend(end); start.extend(end);
Some(Branch { Some(Branch {
goal: branch.goal, goal: branch.goal,
@ -928,8 +934,8 @@ fn needs_tests(pattern: &Pattern) -> bool {
| AppliedTag { .. } | AppliedTag { .. }
| BitLiteral { .. } | BitLiteral { .. }
| EnumLiteral { .. } | EnumLiteral { .. }
| IntLiteral(_) | IntLiteral(_, _)
| FloatLiteral(_) | FloatLiteral(_, _)
| DecimalLiteral(_) | DecimalLiteral(_)
| StrLiteral(_) => true, | StrLiteral(_) => true,
} }
@ -1280,22 +1286,22 @@ fn test_to_equality<'a>(
_ => unreachable!("{:?}", (cond_layout, union)), _ => unreachable!("{:?}", (cond_layout, union)),
} }
} }
Test::IsInt(test_int) => { Test::IsInt(test_int, precision) => {
// TODO don't downcast i128 here // TODO don't downcast i128 here
debug_assert!(test_int <= i64::MAX as i128); debug_assert!(test_int <= i64::MAX as i128);
let lhs = Expr::Literal(Literal::Int(test_int as i128)); let lhs = Expr::Literal(Literal::Int(test_int as i128));
let lhs_symbol = env.unique_symbol(); let lhs_symbol = env.unique_symbol();
stores.push((lhs_symbol, Layout::Builtin(Builtin::Int64), lhs)); stores.push((lhs_symbol, precision.as_layout(), lhs));
(stores, lhs_symbol, rhs_symbol, None) (stores, lhs_symbol, rhs_symbol, None)
} }
Test::IsFloat(test_int) => { Test::IsFloat(test_int, precision) => {
// TODO maybe we can actually use i64 comparison here? // TODO maybe we can actually use i64 comparison here?
let test_float = f64::from_bits(test_int as u64); let test_float = f64::from_bits(test_int as u64);
let lhs = Expr::Literal(Literal::Float(test_float)); let lhs = Expr::Literal(Literal::Float(test_float));
let lhs_symbol = env.unique_symbol(); let lhs_symbol = env.unique_symbol();
stores.push((lhs_symbol, Layout::Builtin(Builtin::Float64), lhs)); stores.push((lhs_symbol, precision.as_layout(), lhs));
(stores, lhs_symbol, rhs_symbol, None) (stores, lhs_symbol, rhs_symbol, None)
} }
@ -1303,7 +1309,7 @@ fn test_to_equality<'a>(
Test::IsDecimal(test_dec) => { Test::IsDecimal(test_dec) => {
let lhs = Expr::Literal(Literal::Int(test_dec.0)); let lhs = Expr::Literal(Literal::Int(test_dec.0));
let lhs_symbol = env.unique_symbol(); let lhs_symbol = env.unique_symbol();
stores.push((lhs_symbol, Layout::Builtin(Builtin::Int128), lhs)); stores.push((lhs_symbol, *cond_layout, lhs));
(stores, lhs_symbol, rhs_symbol, None) (stores, lhs_symbol, rhs_symbol, None)
} }
@ -1737,8 +1743,8 @@ fn decide_to_branching<'a>(
); );
let tag = match test { let tag = match test {
Test::IsInt(v) => v as u64, Test::IsInt(v, _) => v as u64,
Test::IsFloat(v) => v as u64, Test::IsFloat(v, _) => v as u64,
Test::IsBit(v) => v as u64, Test::IsBit(v) => v as u64,
Test::IsByte { tag_id, .. } => tag_id as u64, Test::IsByte { tag_id, .. } => tag_id as u64,
Test::IsCtor { tag_id, .. } => tag_id as u64, Test::IsCtor { tag_id, .. } => tag_id as u64,

View file

@ -65,8 +65,8 @@ fn simplify(pattern: &crate::ir::Pattern) -> Pattern {
use crate::ir::Pattern::*; use crate::ir::Pattern::*;
match pattern { match pattern {
IntLiteral(v) => Literal(Literal::Int(*v)), IntLiteral(v, _) => Literal(Literal::Int(*v)),
FloatLiteral(v) => Literal(Literal::Float(*v)), FloatLiteral(v, _) => Literal(Literal::Float(*v)),
DecimalLiteral(v) => Literal(Literal::Decimal(*v)), DecimalLiteral(v) => Literal(Literal::Decimal(*v)),
StrLiteral(v) => Literal(Literal::Str(v.clone())), StrLiteral(v) => Literal(Literal::Str(v.clone())),

View file

@ -2774,13 +2774,13 @@ pub fn with_hole<'a>(
IntOrFloat::SignedIntType(precision) => Stmt::Let( IntOrFloat::SignedIntType(precision) => Stmt::Let(
assigned, assigned,
Expr::Literal(Literal::Int(int)), Expr::Literal(Literal::Int(int)),
Layout::Builtin(int_precision_to_builtin(precision)), precision.as_layout(),
hole, hole,
), ),
IntOrFloat::UnsignedIntType(precision) => Stmt::Let( IntOrFloat::UnsignedIntType(precision) => Stmt::Let(
assigned, assigned,
Expr::Literal(Literal::Int(int)), Expr::Literal(Literal::Int(int)),
Layout::Builtin(int_precision_to_builtin(precision)), precision.as_layout(),
hole, hole,
), ),
_ => unreachable!("unexpected float precision for integer"), _ => unreachable!("unexpected float precision for integer"),
@ -2792,7 +2792,7 @@ pub fn with_hole<'a>(
IntOrFloat::BinaryFloatType(precision) => Stmt::Let( IntOrFloat::BinaryFloatType(precision) => Stmt::Let(
assigned, assigned,
Expr::Literal(Literal::Float(float)), Expr::Literal(Literal::Float(float)),
Layout::Builtin(float_precision_to_builtin(precision)), precision.as_layout(),
hole, hole,
), ),
IntOrFloat::DecimalFloatType => { IntOrFloat::DecimalFloatType => {
@ -2824,19 +2824,19 @@ pub fn with_hole<'a>(
IntOrFloat::SignedIntType(precision) => Stmt::Let( IntOrFloat::SignedIntType(precision) => Stmt::Let(
assigned, assigned,
Expr::Literal(Literal::Int(num.into())), Expr::Literal(Literal::Int(num.into())),
Layout::Builtin(int_precision_to_builtin(precision)), precision.as_layout(),
hole, hole,
), ),
IntOrFloat::UnsignedIntType(precision) => Stmt::Let( IntOrFloat::UnsignedIntType(precision) => Stmt::Let(
assigned, assigned,
Expr::Literal(Literal::Int(num.into())), Expr::Literal(Literal::Int(num.into())),
Layout::Builtin(int_precision_to_builtin(precision)), precision.as_layout(),
hole, hole,
), ),
IntOrFloat::BinaryFloatType(precision) => Stmt::Let( IntOrFloat::BinaryFloatType(precision) => Stmt::Let(
assigned, assigned,
Expr::Literal(Literal::Float(num as f64)), Expr::Literal(Literal::Float(num as f64)),
Layout::Builtin(float_precision_to_builtin(precision)), precision.as_layout(),
hole, hole,
), ),
IntOrFloat::DecimalFloatType => { IntOrFloat::DecimalFloatType => {
@ -5634,8 +5634,8 @@ fn store_pattern_help<'a>(
// do nothing // do nothing
return StorePattern::NotProductive(stmt); return StorePattern::NotProductive(stmt);
} }
IntLiteral(_) IntLiteral(_, _)
| FloatLiteral(_) | FloatLiteral(_, _)
| DecimalLiteral(_) | DecimalLiteral(_)
| EnumLiteral { .. } | EnumLiteral { .. }
| BitLiteral { .. } | BitLiteral { .. }
@ -5769,8 +5769,8 @@ fn store_tag_pattern<'a>(
Underscore => { Underscore => {
// ignore // ignore
} }
IntLiteral(_) IntLiteral(_, _)
| FloatLiteral(_) | FloatLiteral(_, _)
| DecimalLiteral(_) | DecimalLiteral(_)
| EnumLiteral { .. } | EnumLiteral { .. }
| BitLiteral { .. } | BitLiteral { .. }
@ -5845,8 +5845,8 @@ fn store_newtype_pattern<'a>(
Underscore => { Underscore => {
// ignore // ignore
} }
IntLiteral(_) IntLiteral(_, _)
| FloatLiteral(_) | FloatLiteral(_, _)
| DecimalLiteral(_) | DecimalLiteral(_)
| EnumLiteral { .. } | EnumLiteral { .. }
| BitLiteral { .. } | BitLiteral { .. }
@ -5921,8 +5921,8 @@ fn store_record_destruct<'a>(
// internally. But `y` is never used, so we must make sure it't not stored/loaded. // internally. But `y` is never used, so we must make sure it't not stored/loaded.
return StorePattern::NotProductive(stmt); return StorePattern::NotProductive(stmt);
} }
IntLiteral(_) IntLiteral(_, _)
| FloatLiteral(_) | FloatLiteral(_, _)
| DecimalLiteral(_) | DecimalLiteral(_)
| EnumLiteral { .. } | EnumLiteral { .. }
| BitLiteral { .. } | BitLiteral { .. }
@ -6892,8 +6892,8 @@ fn call_specialized_proc<'a>(
pub enum Pattern<'a> { pub enum Pattern<'a> {
Identifier(Symbol), Identifier(Symbol),
Underscore, Underscore,
IntLiteral(i128), IntLiteral(i128, IntPrecision),
FloatLiteral(u64), FloatLiteral(u64, FloatPrecision),
DecimalLiteral(RocDec), DecimalLiteral(RocDec),
BitLiteral { BitLiteral {
value: bool, value: bool,
@ -6971,7 +6971,22 @@ fn from_can_pattern_help<'a>(
match can_pattern { match can_pattern {
Underscore => Ok(Pattern::Underscore), Underscore => Ok(Pattern::Underscore),
Identifier(symbol) => Ok(Pattern::Identifier(*symbol)), Identifier(symbol) => Ok(Pattern::Identifier(*symbol)),
IntLiteral(_, _, int) => Ok(Pattern::IntLiteral(*int as i128)), IntLiteral(var, _, int) => {
let precision = {
match num_argument_to_int_or_float(env.subs, env.ptr_bytes, *var, false) {
IntOrFloat::SignedIntType(precision)
| IntOrFloat::UnsignedIntType(precision) => precision,
other => {
panic!(
"Invalid precision for int pattern: {:?} has {:?}",
can_pattern, other
)
}
}
};
Ok(Pattern::IntLiteral(*int as i128, precision))
}
FloatLiteral(var, float_str, float) => { FloatLiteral(var, float_str, float) => {
// TODO: Can I reuse num_argument_to_int_or_float here if I pass in true? // TODO: Can I reuse num_argument_to_int_or_float here if I pass in true?
match num_argument_to_int_or_float(env.subs, env.ptr_bytes, *var, true) { match num_argument_to_int_or_float(env.subs, env.ptr_bytes, *var, true) {
@ -6981,7 +6996,9 @@ fn from_can_pattern_help<'a>(
IntOrFloat::UnsignedIntType(_) => { IntOrFloat::UnsignedIntType(_) => {
panic!("Invalid percision for float literal = {:?}", var) panic!("Invalid percision for float literal = {:?}", var)
} }
IntOrFloat::BinaryFloatType(_) => Ok(Pattern::FloatLiteral(f64::to_bits(*float))), IntOrFloat::BinaryFloatType(precision) => {
Ok(Pattern::FloatLiteral(f64::to_bits(*float), precision))
}
IntOrFloat::DecimalFloatType => { IntOrFloat::DecimalFloatType => {
let dec = match RocDec::from_str(float_str) { let dec = match RocDec::from_str(float_str) {
Some(d) => d, Some(d) => d,
@ -7003,9 +7020,15 @@ fn from_can_pattern_help<'a>(
} }
NumLiteral(var, num_str, num) => { NumLiteral(var, num_str, num) => {
match num_argument_to_int_or_float(env.subs, env.ptr_bytes, *var, false) { match num_argument_to_int_or_float(env.subs, env.ptr_bytes, *var, false) {
IntOrFloat::SignedIntType(_) => Ok(Pattern::IntLiteral(*num as i128)), IntOrFloat::SignedIntType(precision) => {
IntOrFloat::UnsignedIntType(_) => Ok(Pattern::IntLiteral(*num as i128)), Ok(Pattern::IntLiteral(*num as i128, precision))
IntOrFloat::BinaryFloatType(_) => Ok(Pattern::FloatLiteral(*num as u64)), }
IntOrFloat::UnsignedIntType(precision) => {
Ok(Pattern::IntLiteral(*num as i128, precision))
}
IntOrFloat::BinaryFloatType(precision) => {
Ok(Pattern::FloatLiteral(*num as u64, precision))
}
IntOrFloat::DecimalFloatType => { IntOrFloat::DecimalFloatType => {
let dec = match RocDec::from_str(num_str) { let dec = match RocDec::from_str(num_str) {
Some(d) => d, Some(d) => d,
@ -7587,7 +7610,7 @@ fn from_can_record_destruct<'a>(
}) })
} }
#[derive(Debug)] #[derive(Debug, Clone, Copy, PartialEq, Hash)]
pub enum IntPrecision { pub enum IntPrecision {
Usize, Usize,
I128, I128,
@ -7597,29 +7620,14 @@ pub enum IntPrecision {
I8, I8,
} }
pub enum FloatPrecision { impl IntPrecision {
F64, pub fn as_layout(&self) -> Layout<'static> {
F32, Layout::Builtin(self.as_builtin())
}
pub enum IntOrFloat {
SignedIntType(IntPrecision),
UnsignedIntType(IntPrecision),
BinaryFloatType(FloatPrecision),
DecimalFloatType,
}
fn float_precision_to_builtin(precision: FloatPrecision) -> Builtin<'static> {
use FloatPrecision::*;
match precision {
F64 => Builtin::Float64,
F32 => Builtin::Float32,
} }
}
fn int_precision_to_builtin(precision: IntPrecision) -> Builtin<'static> { pub fn as_builtin(&self) -> Builtin<'static> {
use IntPrecision::*; use IntPrecision::*;
match precision { match self {
I128 => Builtin::Int128, I128 => Builtin::Int128,
I64 => Builtin::Int64, I64 => Builtin::Int64,
I32 => Builtin::Int32, I32 => Builtin::Int32,
@ -7627,6 +7635,35 @@ fn int_precision_to_builtin(precision: IntPrecision) -> Builtin<'static> {
I8 => Builtin::Int8, I8 => Builtin::Int8,
Usize => Builtin::Usize, Usize => Builtin::Usize,
} }
}
}
#[derive(Debug, Clone, Copy, PartialEq, Hash)]
pub enum FloatPrecision {
F64,
F32,
}
impl FloatPrecision {
pub fn as_layout(&self) -> Layout<'static> {
Layout::Builtin(self.as_builtin())
}
pub fn as_builtin(&self) -> Builtin<'static> {
use FloatPrecision::*;
match self {
F64 => Builtin::Float64,
F32 => Builtin::Float32,
}
}
}
#[derive(Debug)]
pub enum IntOrFloat {
SignedIntType(IntPrecision),
UnsignedIntType(IntPrecision),
BinaryFloatType(FloatPrecision),
DecimalFloatType,
} }
/// Given the `a` in `Num a`, determines whether it's an int or a float /// Given the `a` in `Num a`, determines whether it's an int or a float

View file

@ -1778,4 +1778,48 @@ mod gen_num {
u32 u32
); );
} }
#[test]
fn when_on_i32() {
assert_evals_to!(
indoc!(
r#"
app "test" provides [ main ] to "./platform"
x : I32
x = 0
main : I32
main =
when x is
0 -> 42
_ -> -1
"#
),
42,
i32
);
}
#[test]
fn when_on_i16() {
assert_evals_to!(
indoc!(
r#"
app "test" provides [ main ] to "./platform"
x : I16
x = 0
main : I16
main =
when x is
0 -> 42
_ -> -1
"#
),
42,
i16
);
}
} }