diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 26fff65873..03ceb982fd 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -724,6 +724,14 @@ impl<'a, 'b> Env<'a, 'b> { } } +const fn round_up_to_alignment(width: u32, alignment: u32) -> u32 { + if alignment != 0 && width % alignment > 0 { + width + alignment - (width % alignment) + } else { + width + } +} + impl<'a> Layout<'a> { fn new_help<'b>( env: &mut Env<'a, 'b>, @@ -859,11 +867,7 @@ impl<'a> Layout<'a> { let width = self.stack_size_without_alignment(pointer_size); let alignment = self.alignment_bytes(pointer_size); - if alignment != 0 && width % alignment > 0 { - width + alignment - (width % alignment) - } else { - width - } + round_up_to_alignment(width, alignment) } fn stack_size_without_alignment(&self, pointer_size: u32) -> u32 { @@ -885,6 +889,8 @@ impl<'a> Layout<'a> { match variant { NonRecursive(fields) => { + let tag_id_builtin = variant.tag_id_builtin(); + fields .iter() .map(|tag_layout| { @@ -894,9 +900,10 @@ impl<'a> Layout<'a> { .sum::() }) .max() + .map(|w| round_up_to_alignment(w, tag_id_builtin.alignment_bytes(pointer_size))) .unwrap_or_default() // the size of the tag_id - + variant.tag_id_builtin().stack_size(pointer_size) + + tag_id_builtin.stack_size(pointer_size) } Recursive(_) @@ -924,13 +931,22 @@ impl<'a> Layout<'a> { use UnionLayout::*; match variant { - NonRecursive(tags) => tags - .iter() - .map(|x| x.iter()) - .flatten() - .map(|x| x.alignment_bytes(pointer_size)) - .max() - .unwrap_or(0), + NonRecursive(tags) => { + let tag_id_builtin = variant.tag_id_builtin(); + + tags.iter() + .map(|x| x.iter()) + .flatten() + .map(|x| x.alignment_bytes(pointer_size)) + .max() + .map(|w| { + round_up_to_alignment( + w, + tag_id_builtin.alignment_bytes(pointer_size), + ) + }) + .unwrap_or(0) + } Recursive(_) | NullableWrapped { .. } | NullableUnwrapped { .. } diff --git a/compiler/test_gen/src/gen_tags.rs b/compiler/test_gen/src/gen_tags.rs index adcfcf8bec..fe50a16cde 100644 --- a/compiler/test_gen/src/gen_tags.rs +++ b/compiler/test_gen/src/gen_tags.rs @@ -5,6 +5,23 @@ use crate::assert_evals_to; use indoc::indoc; use roc_std::{RocList, RocStr}; +#[test] +fn width_and_alignment_u8_u8() { + use roc_mono::layout::Builtin; + use roc_mono::layout::Layout; + use roc_mono::layout::UnionLayout; + + let t = &[Layout::Builtin(Builtin::Int8)] as &[_]; + let tt = [t, t]; + + let layout = Layout::Union(UnionLayout::NonRecursive(&tt)); + + // at the moment, the tag id uses an I64, so + let ptr_width = 8; + assert_eq!(layout.alignment_bytes(ptr_width), 8); + assert_eq!(layout.stack_size(ptr_width), 16); +} + #[test] fn applied_tag_nothing_ir() { assert_evals_to!(