Implement hashing tag discriminants in derivers, rather than using low-level

This makes it so we can decide the discriminant in the front-end. With
this, we can also now revert the `LowLevel::TagDiscriminant`
introductions.
This commit is contained in:
Ayaz Hafiz 2022-10-05 12:58:04 -05:00
parent a308ebb38c
commit cb96a64259
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
3 changed files with 75 additions and 81 deletions

View file

@ -3,11 +3,12 @@
use std::iter::once;
use roc_can::{
def::Def,
expr::{AnnotatedMark, ClosureData, Expr, Recursive, WhenBranch, WhenBranchPattern},
expr::{AnnotatedMark, ClosureData, Expr, IntValue, Recursive, WhenBranch, WhenBranchPattern},
num::{IntBound, IntLitWidth},
pattern::Pattern,
};
use roc_derive_key::hash::FlatHashKey;
use roc_error_macros::internal_error;
use roc_module::{
called_via::CalledVia,
ident::{Lowercase, TagName},
@ -15,6 +16,7 @@ use roc_module::{
};
use roc_region::all::{Loc, Region};
use roc_types::{
num::int_lit_width_to_variable,
subs::{
Content, ExhaustiveMark, FlatType, GetSubsSlice, LambdaSet, OptVariable, RecordFields,
RedundantMark, SubsIndex, SubsSlice, UnionLambdas, UnionTags, Variable, VariableSubsSlice,
@ -150,49 +152,40 @@ fn hash_tag_union(
//
// hash_union : hasher, [ A t11 .. t1n, ..., Q tq1 .. tqm ] -> hasher | hasher has Hasher
// hash_union = \hasher, union ->
// discrHasher = Hash.hash hasher (@tag_discriminant union)
// when union is
// A x11 .. x1n -> Hash.hash (... (Hash.hash discrHasher x11) ...) x1n
// A x11 .. x1n -> Hash.hash (... (Hash.hash (Hash.uN hasher 0) x11) ...) x1n
// ...
// Q xq1 .. xqm -> Hash.hash (... (Hash.hash discrHasher xq1) ...) xqm
// Q xq1 .. xqm -> Hash.hash (... (Hash.hash (Hash.uN hasher (q - 1)) xq1) ...) xqm
//
// where `Hash.uN` is the appropriate hasher for the discriminant value - typically a `u8`, but
// if there are more than `u8::MAX` tags, we use `u16`, and so on.
let union_sym = env.new_symbol("union");
let hasher_sym = env.new_symbol("hasher");
let hasher_var = synth_var(env.subs, Content::FlexAbleVar(None, Symbol::HASH_HASHER));
let discr_hasher_sym = env.new_symbol("discrHasher");
// discrHasher = ...
let (discr_hasher_var, discr_hasher_def) = {
let discr_expr = Expr::RunLowLevel {
op: roc_module::low_level::LowLevel::TagDiscriminant,
args: vec![(union_var, Expr::Var(union_sym))],
ret_var: Variable::U16,
};
let discr_var = Variable::U16;
let (discr_hasher_var, disc_hasher_expr) = call_hash_add_u16(
env,
(hasher_var, Expr::Var(hasher_sym)),
(discr_var, discr_expr),
);
let discr_def = Def {
loc_pattern: Loc::at_zero(Pattern::Identifier(discr_hasher_sym)),
loc_expr: Loc::at_zero(disc_hasher_expr),
expr_var: discr_hasher_var,
pattern_vars: once((discr_hasher_sym, discr_hasher_var)).collect(),
annotation: None,
};
(discr_hasher_var, discr_def)
let (discr_width, discr_precision_var, hash_discr_member) = if union_tags.len() > u64::MAX as _
{
// Should never happen, `usize` isn't more than 64 bits on most machines, but who knows?
// Maybe someday soon someone will try to compile a huge Roc program on a 128-bit one.
internal_error!("new record unlocked: you fit more than 18 billion, billion tags in a Roc program, and the compiler didn't fall over! But it will now. 🤯")
} else if union_tags.len() > u32::MAX as _ {
(IntLitWidth::U64, Variable::UNSIGNED64, Symbol::HASH_ADD_U64)
} else if union_tags.len() > u16::MAX as _ {
(IntLitWidth::U32, Variable::UNSIGNED32, Symbol::HASH_ADD_U32)
} else if union_tags.len() > u8::MAX as _ {
(IntLitWidth::U16, Variable::UNSIGNED16, Symbol::HASH_ADD_U16)
} else {
(IntLitWidth::U8, Variable::UNSIGNED8, Symbol::HASH_ADD_U8)
};
let discr_num_var = int_lit_width_to_variable(discr_width);
// Build the branches of the body
let whole_hasher_var = env.subs.fresh_unnamed_flex_var();
let branches = union_tags
.iter_all()
.map(|(tag, payloads)| {
.enumerate()
.map(|(discr_n, (tag, payloads))| {
// A
let tag_name = env.subs[tag].clone();
// t11 .. t1n
@ -218,9 +211,26 @@ fn hash_tag_union(
degenerate: false,
};
// discrHasher = (Hash.uN hasher n)
let (discr_hasher_var, disc_hasher_expr) = call_hash_ability_member(
env,
hash_discr_member,
(hasher_var, Expr::Var(hasher_sym)),
(
discr_num_var,
Expr::Int(
discr_num_var,
discr_precision_var,
format!("{}", discr_n).into_boxed_str(),
IntValue::I128((discr_n as i128).to_ne_bytes()),
IntBound::Exact(discr_width),
),
),
);
// Fold up `Hash.hash (... (Hash.hash discrHasher x11) ...) x1n`
let (body_var, body_expr) = (payload_vars.into_iter()).zip(payload_syms).fold(
(discr_hasher_var, Expr::Var(discr_hasher_sym)),
(discr_hasher_var, disc_hasher_expr),
|total_hasher, (payload_var, payload_sym)| {
call_hash_hash(env, total_hasher, (payload_var, Expr::Var(payload_sym)))
},
@ -250,12 +260,6 @@ fn hash_tag_union(
exhaustive: ExhaustiveMark::known_exhaustive(),
};
let body_var = when_var;
let body_expr = Expr::LetNonRec(
Box::new(discr_hasher_def),
Box::new(Loc::at_zero(when_expr)),
);
// Finally, build the closure
// \hasher, rcd -> body
build_outer_derived_closure(
@ -263,7 +267,7 @@ fn hash_tag_union(
fn_name,
(hasher_var, hasher_sym),
(union_var, Pattern::Identifier(union_sym)),
(body_var, body_expr),
(when_var, when_expr),
)
}
@ -351,14 +355,6 @@ fn hash_newtype_tag_union(
)
}
fn call_hash_add_u16(
env: &mut Env<'_>,
hasher: (Variable, Expr),
val: (Variable, Expr),
) -> (Variable, Expr) {
call_hash_ability_member(env, Symbol::HASH_ADD_U16, hasher, val)
}
fn call_hash_hash(
env: &mut Env<'_>,
hasher: (Variable, Expr),
@ -400,7 +396,7 @@ fn call_hash_ability_member(
env.unify(exposed_hash_fn_var, this_hash_fn_var);
// Hash.hash : hasher, (typeof field) -[clos]-> hasher | hasher has Hasher, (typeof field) has Hash
let hash_fn_head = Expr::AbilityMember(Symbol::HASH_HASH, None, this_hash_fn_var);
let hash_fn_head = Expr::AbilityMember(member, None, this_hash_fn_var);
let hash_fn_data = Box::new((
this_hash_fn_var,
Loc::at_zero(hash_fn_head),

View file

@ -244,14 +244,14 @@ fn tag_two_labels() {
# @<1>: [[hash_[A 3,B 1](0)]]
#Derived.hash_[A 3,B 1] =
\#Derived.hasher, #Derived.union ->
#Derived.discrHasher =
Hash.hash #Derived.hasher (@tag_discriminant #Derived.union)
when #Derived.union is
A #Derived.4 #Derived.5 #Derived.6 ->
A #Derived.3 #Derived.4 #Derived.5 ->
Hash.hash
(Hash.hash (Hash.hash #Derived.discrHasher #Derived.4) #Derived.5)
#Derived.6
B #Derived.7 -> Hash.hash #Derived.discrHasher #Derived.7
(Hash.hash
(Hash.hash (Hash.addU8 #Derived.hasher 0) #Derived.3)
#Derived.4)
#Derived.5
B #Derived.6 -> Hash.hash (Hash.addU8 #Derived.hasher 1) #Derived.6
"###
)
})
@ -268,11 +268,9 @@ fn tag_two_labels_no_payloads() {
# @<1>: [[hash_[A 0,B 0](0)]]
#Derived.hash_[A 0,B 0] =
\#Derived.hasher, #Derived.union ->
#Derived.discrHasher =
Hash.hash #Derived.hasher (@tag_discriminant #Derived.union)
when #Derived.union is
A -> #Derived.discrHasher
B -> #Derived.discrHasher
A -> Hash.addU8 #Derived.hasher 0
B -> Hash.addU8 #Derived.hasher 1
"###
)
})
@ -289,12 +287,12 @@ fn recursive_tag_union() {
# @<1>: [[hash_[Cons 2,Nil 0](0)]]
#Derived.hash_[Cons 2,Nil 0] =
\#Derived.hasher, #Derived.union ->
#Derived.discrHasher =
Hash.hash #Derived.hasher (@tag_discriminant #Derived.union)
when #Derived.union is
Cons #Derived.4 #Derived.5 ->
Hash.hash (Hash.hash #Derived.discrHasher #Derived.4) #Derived.5
Nil -> #Derived.discrHasher
Cons #Derived.3 #Derived.4 ->
Hash.hash
(Hash.hash (Hash.addU8 #Derived.hasher 0) #Derived.3)
#Derived.4
Nil -> Hash.addU8 #Derived.hasher 1
"###
)
})

View file

@ -1427,8 +1427,8 @@ mod hash {
TEST_HASHER,
),
RocList::from_slice(&[
0, 0, // A
1, 0, // B
0, // A
1, // B
]),
RocList<u8>
)
@ -1456,14 +1456,14 @@ mod hash {
TEST_HASHER,
),
RocList::from_slice(&[
0, 0, // A
1, 0, // B
2, 0, // C
3, 0, // D
4, 0, // E
5, 0, // F
6, 0, // G
7, 0, // H
0, // A
1, // B
2, // C
3, // D
4, // E
5, // F
6, // G
7, // H
]),
RocList<u8>
)
@ -1520,7 +1520,7 @@ mod hash {
TEST_HASHER,
),
RocList::from_slice(&[
0, 0, // Ok
1, // Ok
// A is skipped because it is a newtype
15, 23, 47
]),
@ -1558,11 +1558,11 @@ mod hash {
TEST_HASHER,
),
RocList::from_slice(&[
0, 0, // dicsr A
0, // dicsr A
15, 23, // payloads A
1, 0, // discr B
1, // discr B
37, // payloads B
2, 0, // discr C
2, // discr C
97, 98, 99 // payloads C
]),
RocList<u8>
@ -1593,9 +1593,9 @@ mod hash {
TEST_HASHER,
),
RocList::from_slice(&[
0, 0, 1, // Cons 1
0, 0, 2, // Cons 2
1, 0, // Nil
0, 1, // Cons 1
0, 2, // Cons 2
1, // Nil
]),
RocList<u8>
)