Get started with semantic layouts for tag unions

This commit is contained in:
Ayaz Hafiz 2023-05-10 18:47:10 -05:00
parent 24e65cbf8d
commit 8ca71c7eda
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
6 changed files with 95 additions and 29 deletions

View file

@ -13,8 +13,8 @@ use roc_problem::can::RuntimeError;
use roc_target::{PtrWidth, TargetInfo}; use roc_target::{PtrWidth, TargetInfo};
use roc_types::num::NumericRange; use roc_types::num::NumericRange;
use roc_types::subs::{ use roc_types::subs::{
self, Content, FlatType, GetSubsSlice, Label, OptVariable, RecordFields, Subs, TagExt, self, Content, FlatType, GetSubsSlice, OptVariable, RecordFields, Subs, TagExt, TupleElems,
TupleElems, UnsortedUnionLabels, Variable, VariableSubsSlice, UnsortedUnionLabels, Variable, VariableSubsSlice,
}; };
use roc_types::types::{ use roc_types::types::{
gather_fields_unsorted_iter, gather_tuple_elems_unsorted_iter, RecordField, RecordFieldsError, gather_fields_unsorted_iter, gather_tuple_elems_unsorted_iter, RecordField, RecordFieldsError,
@ -2505,6 +2505,10 @@ impl<'a> std::ops::Deref for Layout<'a> {
} }
impl<'a> LayoutRepr<'a> { impl<'a> LayoutRepr<'a> {
const UNIT: Self = LayoutRepr::struct_(&[]);
const BOOL: Self = LayoutRepr::Builtin(Builtin::Bool);
const U8: Self = LayoutRepr::Builtin(Builtin::Int(IntWidth::U8));
pub const fn struct_(field_layouts: &'a [InLayout<'a>]) -> Self { pub const fn struct_(field_layouts: &'a [InLayout<'a>]) -> Self {
Self::Struct { field_layouts } Self::Struct { field_layouts }
} }
@ -3630,9 +3634,38 @@ fn get_recursion_var(subs: &Subs, var: Variable) -> Option<Variable> {
} }
} }
trait Label: subs::Label + Ord + Clone + Into<TagOrClosure> {
fn semantic_repr<'a, 'r>(
arena: &'a Bump,
labels: impl ExactSizeIterator<Item = &'r Self>,
) -> SemanticRepr<'a>
where
Self: 'r;
}
impl Label for TagName {
fn semantic_repr<'a, 'r>(
arena: &'a Bump,
labels: impl ExactSizeIterator<Item = &'r Self>,
) -> SemanticRepr<'a> {
SemanticRepr::tag_union(
arena.alloc_slice_fill_iter(labels.map(|x| &*arena.alloc_str(x.0.as_str()))),
)
}
}
impl Label for Symbol {
fn semantic_repr<'a, 'r>(
arena: &'a Bump,
labels: impl ExactSizeIterator<Item = &'r Self>,
) -> SemanticRepr<'a> {
SemanticRepr::lambdas(arena.alloc_slice_fill_iter(labels.copied()))
}
}
fn union_sorted_non_recursive_tags_help<'a, L>( fn union_sorted_non_recursive_tags_help<'a, L>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
tags_list: &[(&'_ L, &[Variable])], tags_list: &mut Vec<'_, &'_ (&'_ L, &[Variable])>,
) -> Cacheable<UnionVariant<'a>> ) -> Cacheable<UnionVariant<'a>>
where where
L: Label + Ord + Clone + Into<TagOrClosure>, L: Label + Ord + Clone + Into<TagOrClosure>,
@ -3640,7 +3673,6 @@ where
let mut cache_criteria = CACHEABLE; let mut cache_criteria = CACHEABLE;
// sort up front; make sure the ordering stays intact! // sort up front; make sure the ordering stays intact!
let mut tags_list = Vec::from_iter_in(tags_list.iter(), env.arena);
tags_list.sort_unstable_by(|(a, _), (b, _)| a.cmp(b)); tags_list.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
match tags_list.len() { match tags_list.len() {
@ -3649,7 +3681,7 @@ where
Cacheable(UnionVariant::Never, cache_criteria) Cacheable(UnionVariant::Never, cache_criteria)
} }
1 => { 1 => {
let &(tag_name, arguments) = tags_list.remove(0); let &&(tag_name, arguments) = &tags_list[0];
let tag_name = tag_name.clone().into(); let tag_name = tag_name.clone().into();
// just one tag in the union (but with arguments) can be a struct // just one tag in the union (but with arguments) can be a struct
@ -3708,7 +3740,7 @@ where
let mut inhabited_tag_ids = BitVec::<usize>::repeat(true, num_tags); let mut inhabited_tag_ids = BitVec::<usize>::repeat(true, num_tags);
for &(tag_name, arguments) in tags_list.into_iter() { for &&(tag_name, arguments) in tags_list.iter() {
let mut arg_layouts = Vec::with_capacity_in(arguments.len() + 1, env.arena); let mut arg_layouts = Vec::with_capacity_in(arguments.len() + 1, env.arena);
for &var in arguments { for &var in arguments {
@ -4078,18 +4110,26 @@ where
return layout_from_newtype(env, tags); return layout_from_newtype(env, tags);
} }
let tags_vec = &tags.tags; let mut tags_vec = Vec::from_iter_in(tags.tags.iter(), env.arena);
let mut criteria = CACHEABLE; let mut criteria = CACHEABLE;
let variant = let variant =
union_sorted_non_recursive_tags_help(env, tags_vec).decompose(&mut criteria, env.subs); union_sorted_non_recursive_tags_help(env, &mut tags_vec).decompose(&mut criteria, env.subs);
let compute_semantic = || L::semantic_repr(env.arena, tags_vec.iter().map(|(l, _)| *l));
let result = match variant { let result = match variant {
Never => Layout::VOID, Never => Layout::VOID,
Unit => Layout::UNIT, Unit => env
BoolUnion { .. } => Layout::BOOL, .cache
ByteUnion(_) => Layout::U8, .put_in(Layout::new(LayoutRepr::UNIT, compute_semantic())),
BoolUnion { .. } => env
.cache
.put_in(Layout::new(LayoutRepr::BOOL, compute_semantic())),
ByteUnion(_) => env
.cache
.put_in(Layout::new(LayoutRepr::U8, compute_semantic())),
Newtype { Newtype {
arguments: field_layouts, arguments: field_layouts,
.. ..

View file

@ -1,5 +1,7 @@
//! Semantic representations of memory layouts for the purposes of specialization. //! Semantic representations of memory layouts for the purposes of specialization.
use roc_module::symbol::Symbol;
/// A semantic representation of a memory layout. /// A semantic representation of a memory layout.
/// Semantic representations describe the shape of a type a [Layout][super::Layout] is generated /// Semantic representations describe the shape of a type a [Layout][super::Layout] is generated
/// for. Semantic representations disambiguate types that have the same runtime memory layout, but /// for. Semantic representations disambiguate types that have the same runtime memory layout, but
@ -18,6 +20,8 @@ enum Inner<'a> {
None, None,
Record(SemaRecord<'a>), Record(SemaRecord<'a>),
Tuple(SemaTuple), Tuple(SemaTuple),
TagUnion(SemaTagUnion<'a>),
Lambdas(SemaLambdas<'a>),
} }
impl<'a> SemanticRepr<'a> { impl<'a> SemanticRepr<'a> {
@ -31,6 +35,14 @@ impl<'a> SemanticRepr<'a> {
pub(super) fn tuple(size: usize) -> Self { pub(super) fn tuple(size: usize) -> Self {
Self(Inner::Tuple(SemaTuple { size })) Self(Inner::Tuple(SemaTuple { size }))
} }
pub(super) fn tag_union(tags: &'a [&'a str]) -> Self {
Self(Inner::TagUnion(SemaTagUnion { tags }))
}
pub(super) fn lambdas(lambdas: &'a [Symbol]) -> Self {
Self(Inner::Lambdas(SemaLambdas { lambdas }))
}
} }
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
@ -42,3 +54,13 @@ struct SemaRecord<'a> {
struct SemaTuple { struct SemaTuple {
size: usize, size: usize,
} }
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
struct SemaTagUnion<'a> {
tags: &'a [&'a str],
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
struct SemaLambdas<'a> {
lambdas: &'a [Symbol],
}

View file

@ -58,7 +58,7 @@ procedure Test.11 (Test.53, Test.54):
joinpoint Test.27 Test.12 #Attr.12: joinpoint Test.27 Test.12 #Attr.12:
let Test.8 : Int1 = UnionAtIndex (Id 2) (Index 1) #Attr.12; let Test.8 : Int1 = UnionAtIndex (Id 2) (Index 1) #Attr.12;
let Test.7 : [<rnw><null>, C *self Int1, C *self Int1] = UnionAtIndex (Id 2) (Index 0) #Attr.12; let Test.7 : [<rnw><null>, C *self Int1, C *self Int1] = UnionAtIndex (Id 2) (Index 0) #Attr.12;
joinpoint #Derived_gen.5: joinpoint #Derived_gen.3:
joinpoint Test.31 Test.29: joinpoint Test.31 Test.29:
let Test.30 : U8 = GetTagId Test.7; let Test.30 : U8 = GetTagId Test.7;
switch Test.30: switch Test.30:
@ -85,14 +85,14 @@ procedure Test.11 (Test.53, Test.54):
jump Test.31 Test.32; jump Test.31 Test.32;
in in
let #Derived_gen.6 : Int1 = lowlevel RefCountIsUnique #Attr.12; let #Derived_gen.4 : Int1 = lowlevel RefCountIsUnique #Attr.12;
if #Derived_gen.6 then if #Derived_gen.4 then
decref #Attr.12; decref #Attr.12;
jump #Derived_gen.5; jump #Derived_gen.3;
else else
inc Test.7; inc Test.7;
decref #Attr.12; decref #Attr.12;
jump #Derived_gen.5; jump #Derived_gen.3;
in in
jump Test.27 Test.53 Test.54; jump Test.27 Test.53 Test.54;
@ -125,7 +125,7 @@ procedure Test.6 (Test.7, Test.8, Test.5):
procedure Test.9 (Test.10, #Attr.12): procedure Test.9 (Test.10, #Attr.12):
let Test.8 : Int1 = UnionAtIndex (Id 1) (Index 1) #Attr.12; let Test.8 : Int1 = UnionAtIndex (Id 1) (Index 1) #Attr.12;
let Test.7 : [<rnw><null>, C *self Int1, C *self Int1] = UnionAtIndex (Id 1) (Index 0) #Attr.12; let Test.7 : [<rnw><null>, C *self Int1, C *self Int1] = UnionAtIndex (Id 1) (Index 0) #Attr.12;
joinpoint #Derived_gen.3: joinpoint #Derived_gen.5:
let Test.37 : U8 = GetTagId Test.7; let Test.37 : U8 = GetTagId Test.7;
joinpoint Test.38 Test.36: joinpoint Test.38 Test.36:
switch Test.8: switch Test.8:
@ -153,14 +153,14 @@ procedure Test.9 (Test.10, #Attr.12):
jump Test.38 Test.39; jump Test.38 Test.39;
in in
let #Derived_gen.4 : Int1 = lowlevel RefCountIsUnique #Attr.12; let #Derived_gen.6 : Int1 = lowlevel RefCountIsUnique #Attr.12;
if #Derived_gen.4 then if #Derived_gen.6 then
decref #Attr.12; decref #Attr.12;
jump #Derived_gen.3; jump #Derived_gen.5;
else else
inc Test.7; inc Test.7;
decref #Attr.12; decref #Attr.12;
jump #Derived_gen.3; jump #Derived_gen.5;
procedure Test.0 (): procedure Test.0 ():
let Test.41 : Int1 = false; let Test.41 : Int1 = false;

View file

@ -212,8 +212,8 @@ procedure Json.42 (Json.298):
let Json.496 : U64 = 1i64; let Json.496 : U64 = 1i64;
let Json.495 : {List U8, List U8} = CallByName List.52 Json.304 Json.496; let Json.495 : {List U8, List U8} = CallByName List.52 Json.304 Json.496;
let Json.309 : List U8 = StructAtIndex 1 Json.495; let Json.309 : List U8 = StructAtIndex 1 Json.495;
let #Derived_gen.0 : List U8 = StructAtIndex 0 Json.495; let #Derived_gen.1 : List U8 = StructAtIndex 0 Json.495;
dec #Derived_gen.0; dec #Derived_gen.1;
let Json.494 : [C {}, C Str] = TagId(1) Json.307; let Json.494 : [C {}, C Str] = TagId(1) Json.307;
let Json.493 : {List U8, [C {}, C Str]} = Struct {Json.309, Json.494}; let Json.493 : {List U8, [C {}, C Str]} = Struct {Json.309, Json.494};
ret Json.493; ret Json.493;
@ -347,8 +347,8 @@ procedure Str.9 (Str.79):
else else
let Str.300 : U8 = StructAtIndex 3 Str.80; let Str.300 : U8 = StructAtIndex 3 Str.80;
let Str.301 : U64 = StructAtIndex 0 Str.80; let Str.301 : U64 = StructAtIndex 0 Str.80;
let #Derived_gen.1 : Str = StructAtIndex 1 Str.80; let #Derived_gen.0 : Str = StructAtIndex 1 Str.80;
dec #Derived_gen.1; dec #Derived_gen.0;
let Str.299 : {U64, U8} = Struct {Str.301, Str.300}; let Str.299 : {U64, U8} = Struct {Str.301, Str.300};
let Str.298 : [C {U64, U8}, C Str] = TagId(0) Str.299; let Str.298 : [C {U64, U8}, C Str] = TagId(0) Str.299;
ret Str.298; ret Str.298;

View file

@ -190,8 +190,8 @@ procedure Json.42 (Json.298):
let Json.496 : U64 = 1i64; let Json.496 : U64 = 1i64;
let Json.495 : {List U8, List U8} = CallByName List.52 Json.304 Json.496; let Json.495 : {List U8, List U8} = CallByName List.52 Json.304 Json.496;
let Json.309 : List U8 = StructAtIndex 1 Json.495; let Json.309 : List U8 = StructAtIndex 1 Json.495;
let #Derived_gen.0 : List U8 = StructAtIndex 0 Json.495; let #Derived_gen.1 : List U8 = StructAtIndex 0 Json.495;
dec #Derived_gen.0; dec #Derived_gen.1;
let Json.494 : [C {}, C Str] = TagId(1) Json.307; let Json.494 : [C {}, C Str] = TagId(1) Json.307;
let Json.493 : {List U8, [C {}, C Str]} = Struct {Json.309, Json.494}; let Json.493 : {List U8, [C {}, C Str]} = Struct {Json.309, Json.494};
ret Json.493; ret Json.493;
@ -345,8 +345,8 @@ procedure Str.9 (Str.79):
else else
let Str.314 : U8 = StructAtIndex 3 Str.80; let Str.314 : U8 = StructAtIndex 3 Str.80;
let Str.315 : U64 = StructAtIndex 0 Str.80; let Str.315 : U64 = StructAtIndex 0 Str.80;
let #Derived_gen.1 : Str = StructAtIndex 1 Str.80; let #Derived_gen.0 : Str = StructAtIndex 1 Str.80;
dec #Derived_gen.1; dec #Derived_gen.0;
let Str.313 : {U64, U8} = Struct {Str.315, Str.314}; let Str.313 : {U64, U8} = Struct {Str.315, Str.314};
let Str.312 : [C {U64, U8}, C Str] = TagId(0) Str.313; let Str.312 : [C {U64, U8}, C Str] = TagId(0) Str.313;
ret Str.312; ret Str.312;

View file

@ -2,6 +2,10 @@ procedure Test.1 (Test.4):
let Test.12 : Int1 = false; let Test.12 : Int1 = false;
ret Test.12; ret Test.12;
procedure Test.1 (Test.4):
let Test.14 : Int1 = false;
ret Test.14;
procedure Test.2 (Test.5, Test.6): procedure Test.2 (Test.5, Test.6):
let Test.10 : U8 = 18i64; let Test.10 : U8 = 18i64;
ret Test.10; ret Test.10;