First pass at serializing the abilities store

This commit is contained in:
Ayaz Hafiz 2022-10-10 19:31:53 -05:00
parent 781d1a2642
commit 9131a55a72
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58

View file

@ -100,6 +100,7 @@ impl MemberSpecializationInfo<Resolved> {
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct SpecializationId(NonZeroU32);
static_assertions::assert_eq_size!(SpecializationId, Option<SpecializationId>);
@ -469,6 +470,14 @@ impl IAbilitiesStore<Resolved> {
pub fn get_resolved(&self, id: SpecializationId) -> Option<Symbol> {
self.resolved_specializations.get(&id).copied()
}
pub fn serialize(&self, writer: &mut impl std::io::Write) -> std::io::Result<usize> {
serialize::serialize(self, writer)
}
pub fn deserialize(bytes: &[u8]) -> Self {
serialize::deserialize(bytes)
}
}
impl IAbilitiesStore<Pending> {
@ -660,3 +669,565 @@ impl IAbilitiesStore<Pending> {
}
}
}
mod serialize {
use roc_collections::{MutMap, VecMap};
use roc_module::symbol::Symbol;
use roc_region::all::Region;
use roc_types::{
subs::{SubsSlice, Variable},
types::MemberImpl,
};
use super::{
AbilitiesStore, AbilityMemberData, ImplKey, MemberSpecializationInfo, Resolved,
ResolvedMemberType, SpecializationId,
};
use std::{
borrow::Borrow,
io::{self, Write},
};
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct Header {
members_of_ability: u64,
specialization_to_root: u64,
ability_members: u64,
declared_implementations: u64,
specializations: u64,
next_specialization_id: u64,
resolved_specializations: u64,
}
impl Header {
fn from_store(store: &AbilitiesStore) -> Self {
let AbilitiesStore {
members_of_ability,
specialization_to_root,
ability_members,
declared_implementations,
specializations,
next_specialization_id,
resolved_specializations,
} = store;
Self {
members_of_ability: members_of_ability.len() as _,
specialization_to_root: specialization_to_root.len() as _,
ability_members: ability_members.len() as _,
declared_implementations: declared_implementations.len() as _,
specializations: specializations.len() as _,
next_specialization_id: next_specialization_id.get() as _,
resolved_specializations: resolved_specializations.len() as _,
}
}
fn to_array(self) -> [u8; std::mem::size_of::<Self>()] {
// Safety: With repr(c) all fields are in order and properly aligned without padding.
unsafe { std::mem::transmute(self) }
}
fn from_array(array: [u8; std::mem::size_of::<Self>()]) -> Self {
// Safety: With repr(c) all fields are in order and properly aligned without padding.
unsafe { std::mem::transmute(array) }
}
}
pub(super) fn serialize(store: &AbilitiesStore, writer: &mut impl Write) -> io::Result<usize> {
let header = Header::from_store(store).to_array();
let written = header.len();
writer.write_all(&header)?;
let AbilitiesStore {
members_of_ability,
specialization_to_root,
ability_members,
declared_implementations,
specializations,
next_specialization_id: _, // written in the header
resolved_specializations,
} = store;
let written = serialize_members_of_ability(members_of_ability, writer, written)?;
let written = serialize_specializations_to_root(specialization_to_root, writer, written)?;
let written = serialize_ability_members(ability_members, writer, written)?;
let written =
serialize_declared_implementations(declared_implementations, writer, written)?;
let written = serialize_specializations(specializations, writer, written)?;
let written =
serialize_resolved_specializations(resolved_specializations, writer, written)?;
Ok(written)
}
pub(super) fn deserialize(bytes: &[u8]) -> AbilitiesStore {
let mut offset = 0;
let header_slice = &bytes[..std::mem::size_of::<Header>()];
offset += header_slice.len();
let header = Header::from_array(header_slice.try_into().unwrap());
let (members_of_ability, offset) =
deserialize_members_of_ability(bytes, header.members_of_ability as _, offset);
let (specialization_to_root, offset) =
deserialize_specialization_to_root(bytes, header.specialization_to_root as _, offset);
let (ability_members, offset) =
deserialize_ability_members(bytes, header.ability_members as _, offset);
let (declared_implementations, offset) = deserialize_declared_implementations(
bytes,
header.declared_implementations as _,
offset,
);
let (specializations, offset) =
deserialize_specializations(bytes, header.specializations as _, offset);
let (resolved_specializations, offset) = deserialize_resolved_specializations(
bytes,
header.resolved_specializations as _,
offset,
);
let _ = offset;
AbilitiesStore {
members_of_ability,
specialization_to_root,
ability_members,
declared_implementations,
specializations,
next_specialization_id: (header.next_specialization_id as u32).try_into().unwrap(),
resolved_specializations,
}
}
fn serialize_members_of_ability(
members_of_ability: &MutMap<Symbol, Vec<Symbol>>,
writer: &mut impl Write,
written: usize,
) -> io::Result<usize> {
serialize_map(
members_of_ability,
serialize_slice,
serialize_slice_of_slices,
writer,
written,
)
}
fn deserialize_members_of_ability(
bytes: &[u8],
length: usize,
offset: usize,
) -> (MutMap<Symbol, Vec<Symbol>>, usize) {
deserialize_map(
bytes,
deserialize_vec,
deserialize_slice_of_slices,
length,
offset,
)
}
#[repr(C)]
struct SerImplKey(Symbol, Symbol);
impl From<&ImplKey> for SerImplKey {
fn from(k: &ImplKey) -> Self {
Self(k.opaque, k.ability_member)
}
}
impl From<&SerImplKey> for ImplKey {
fn from(k: &SerImplKey) -> Self {
Self {
opaque: k.0,
ability_member: k.1,
}
}
}
fn serialize_specializations_to_root(
specialization_to_root: &MutMap<Symbol, ImplKey>,
writer: &mut impl Write,
written: usize,
) -> io::Result<usize> {
serialize_map(
specialization_to_root,
serialize_slice,
|keys, writer, written| {
serialize_slice(
&keys.iter().map(SerImplKey::from).collect::<Vec<_>>(),
writer,
written,
)
},
writer,
written,
)
}
fn deserialize_specialization_to_root(
bytes: &[u8],
length: usize,
offset: usize,
) -> (MutMap<Symbol, ImplKey>, usize) {
deserialize_map(
bytes,
deserialize_vec,
|bytes, length, offset| {
let (slice, offset) = deserialize_slice::<SerImplKey>(bytes, length, offset);
(slice.iter().map(ImplKey::from).collect(), offset)
},
length,
offset,
)
}
#[repr(C)]
struct SerMemberData(Symbol, Region, Variable);
impl From<&AbilityMemberData<Resolved>> for SerMemberData {
fn from(k: &AbilityMemberData<Resolved>) -> Self {
Self(k.parent_ability, k.region, k.typ.0)
}
}
impl From<&SerMemberData> for AbilityMemberData<Resolved> {
fn from(k: &SerMemberData) -> Self {
Self {
parent_ability: k.0,
region: k.1,
typ: ResolvedMemberType(k.2),
}
}
}
fn serialize_ability_members(
ability_members: &MutMap<Symbol, AbilityMemberData<Resolved>>,
writer: &mut impl Write,
written: usize,
) -> io::Result<usize> {
serialize_map(
ability_members,
serialize_slice,
|keys, writer, written| {
serialize_slice(
&keys.iter().map(SerMemberData::from).collect::<Vec<_>>(),
writer,
written,
)
},
writer,
written,
)
}
fn deserialize_ability_members(
bytes: &[u8],
length: usize,
offset: usize,
) -> (MutMap<Symbol, AbilityMemberData<Resolved>>, usize) {
deserialize_map(
bytes,
deserialize_vec,
|bytes, length, offset| {
let (slice, offset) = deserialize_slice::<SerMemberData>(bytes, length, offset);
(slice.iter().map(AbilityMemberData::from).collect(), offset)
},
length,
offset,
)
}
#[repr(C)]
enum SerMemberImpl {
Impl(Symbol),
Derived,
Error,
}
impl From<&MemberImpl> for SerMemberImpl {
fn from(k: &MemberImpl) -> Self {
match k {
MemberImpl::Impl(s) => Self::Impl(*s),
MemberImpl::Derived => Self::Derived,
MemberImpl::Error => Self::Error,
}
}
}
impl From<&SerMemberImpl> for MemberImpl {
fn from(k: &SerMemberImpl) -> Self {
match k {
SerMemberImpl::Impl(s) => Self::Impl(*s),
SerMemberImpl::Derived => Self::Derived,
SerMemberImpl::Error => Self::Error,
}
}
}
fn serialize_declared_implementations(
declared_implementations: &MutMap<ImplKey, MemberImpl>,
writer: &mut impl Write,
written: usize,
) -> io::Result<usize> {
serialize_map(
declared_implementations,
serialize_slice,
|keys, writer, written| {
serialize_slice(
&keys.iter().map(SerMemberImpl::from).collect::<Vec<_>>(),
writer,
written,
)
},
writer,
written,
)
}
fn deserialize_declared_implementations(
bytes: &[u8],
length: usize,
offset: usize,
) -> (MutMap<ImplKey, MemberImpl>, usize) {
deserialize_map(
bytes,
deserialize_vec,
|bytes, length, offset| {
let (slice, offset) = deserialize_slice::<SerMemberImpl>(bytes, length, offset);
(slice.iter().map(MemberImpl::from).collect(), offset)
},
length,
offset,
)
}
#[repr(C)]
struct SerMemberSpecInfo(Symbol, SubsSlice<u8>, SubsSlice<Variable>);
fn serialize_specializations(
specializations: &MutMap<Symbol, MemberSpecializationInfo<Resolved>>,
writer: &mut impl Write,
written: usize,
) -> io::Result<usize> {
serialize_map(
specializations,
serialize_slice,
|spec_info, writer, written| {
let mut spec_lambda_sets_regions: Vec<u8> = Vec::new();
let mut spec_lambda_sets_vars: Vec<Variable> = Vec::new();
let mut ser_member_spec_infos: Vec<SerMemberSpecInfo> = Vec::new();
for MemberSpecializationInfo {
_phase: _,
symbol,
specialization_lambda_sets,
} in spec_info
{
let regions = SubsSlice::extend_new(
&mut spec_lambda_sets_regions,
specialization_lambda_sets.keys().copied(),
);
let vars = SubsSlice::extend_new(
&mut spec_lambda_sets_vars,
specialization_lambda_sets.values().copied(),
);
ser_member_spec_infos.push(SerMemberSpecInfo(*symbol, regions, vars));
}
let written = serialize_slice(&ser_member_spec_infos, writer, written)?;
let written = serialize_slice(&spec_lambda_sets_regions, writer, written)?;
let written = serialize_slice(&spec_lambda_sets_vars, writer, written)?;
Ok(written)
},
writer,
written,
)
}
fn deserialize_specializations(
bytes: &[u8],
length: usize,
offset: usize,
) -> (MutMap<Symbol, MemberSpecializationInfo<Resolved>>, usize) {
deserialize_map(
bytes,
deserialize_vec,
|bytes, length, offset| {
let (serialized_slices, offset) =
deserialize_slice::<SerMemberSpecInfo>(bytes, length, offset);
let (regions_slice, offset) = {
let total_items = serialized_slices.iter().map(|s| s.1.len()).sum();
deserialize_slice::<u8>(bytes, total_items, offset)
};
let (vars_slice, offset) = {
let total_items = serialized_slices.iter().map(|s| s.2.len()).sum();
deserialize_slice::<Variable>(bytes, total_items, offset)
};
let mut spec_infos: Vec<MemberSpecializationInfo<Resolved>> =
Vec::with_capacity(length);
for SerMemberSpecInfo(symbol, regions, vars) in serialized_slices {
let regions = regions_slice[regions.indices()].to_vec();
let lset_vars = vars_slice[vars.indices()].to_vec();
let spec_info = MemberSpecializationInfo::new(*symbol, unsafe {
VecMap::zip(regions, lset_vars)
});
spec_infos.push(spec_info)
}
(spec_infos, offset)
},
length,
offset,
)
}
fn serialize_resolved_specializations(
resolved_specializations: &MutMap<SpecializationId, Symbol>,
writer: &mut impl Write,
written: usize,
) -> io::Result<usize> {
serialize_map(
resolved_specializations,
serialize_slice,
serialize_slice,
writer,
written,
)
}
fn deserialize_resolved_specializations(
bytes: &[u8],
length: usize,
offset: usize,
) -> (MutMap<SpecializationId, Symbol>, usize) {
deserialize_map(bytes, deserialize_vec, deserialize_vec, length, offset)
}
fn serialize_map<K: Clone, V: Clone, W: Write>(
map: &MutMap<K, V>,
ser_keys: fn(&[K], &mut W, usize) -> io::Result<usize>,
ser_values: fn(&[V], &mut W, usize) -> io::Result<usize>,
writer: &mut W,
written: usize,
) -> io::Result<usize> {
let keys = map.keys().cloned().collect::<Vec<_>>();
let values = map.values().cloned().collect::<Vec<_>>();
let written = ser_keys(keys.as_slice(), writer, written)?;
let written = ser_values(values.as_slice(), writer, written)?;
Ok(written)
}
#[allow(clippy::type_complexity)]
fn deserialize_map<K, V>(
bytes: &[u8],
deser_keys: fn(&[u8], usize, usize) -> (Vec<K>, usize),
deser_values: fn(&[u8], usize, usize) -> (Vec<V>, usize),
length: usize,
offset: usize,
) -> (MutMap<K, V>, usize)
where
K: Clone + std::hash::Hash + Eq,
V: Clone,
{
let (keys, offset) = deser_keys(bytes, length, offset);
let (values, offset) = deser_values(bytes, length, offset);
(
MutMap::from_iter((keys.iter().cloned()).zip(values.iter().cloned())),
offset,
)
}
fn serialize_slice_of_slices<'a, T, U>(
slice_of_slices: &[U],
writer: &mut impl Write,
written: usize,
) -> io::Result<usize>
where
T: 'a + Copy,
U: 'a + Borrow<[T]> + Sized,
{
let mut item_buf: Vec<T> = Vec::new();
let mut serialized_slices: Vec<SubsSlice<T>> = Vec::new();
for slice in slice_of_slices {
let slice = SubsSlice::extend_new(&mut item_buf, slice.borrow().iter().copied());
serialized_slices.push(slice);
}
let written = serialize_slice(&serialized_slices, writer, written)?;
serialize_slice(&item_buf, writer, written)
}
fn deserialize_slice_of_slices<T: Clone>(
bytes: &[u8],
length: usize,
offset: usize,
) -> (Vec<Vec<T>>, usize) {
let (serialized_slices, offset) = deserialize_slice::<SubsSlice<T>>(bytes, length, offset);
let (vars_slice, offset) = {
let total_items = serialized_slices.iter().map(|s| s.len()).sum();
deserialize_slice::<T>(bytes, total_items, offset)
};
let mut slice_of_slices = Vec::with_capacity(length);
for slice in serialized_slices {
let deserialized_slice = &vars_slice[slice.indices()];
slice_of_slices.push(deserialized_slice.to_vec())
}
(slice_of_slices, offset)
}
fn serialize_slice<'a, 'b, 'c, T: 'a>(
slice: &'b [T],
writer: &'c mut impl Write,
written: usize,
) -> std::io::Result<usize> {
let alignment = std::mem::align_of::<T>();
let padding_bytes = round_to_multiple_of(written, alignment) - written;
for _ in 0..padding_bytes {
writer.write_all(&[0])?;
}
let bytes_slice = unsafe { slice_as_bytes(slice) };
writer.write_all(bytes_slice)?;
Ok(written + padding_bytes + bytes_slice.len())
}
fn deserialize_slice<T>(bytes: &[u8], length: usize, mut offset: usize) -> (&[T], usize) {
let alignment = std::mem::align_of::<T>();
let size = std::mem::size_of::<T>();
offset = round_to_multiple_of(offset, alignment);
let byte_length = length * size;
let byte_slice = &bytes[offset..][..byte_length];
let slice = unsafe { std::slice::from_raw_parts(byte_slice.as_ptr() as *const T, length) };
(slice, offset + byte_length)
}
fn deserialize_vec<T: Clone>(bytes: &[u8], length: usize, offset: usize) -> (Vec<T>, usize) {
let (slice, offset) = deserialize_slice(bytes, length, offset);
(slice.to_vec(), offset)
}
unsafe fn slice_as_bytes<T>(slice: &[T]) -> &[u8] {
let ptr = slice.as_ptr();
let byte_length = std::mem::size_of::<T>() * slice.len();
std::slice::from_raw_parts(ptr as *const u8, byte_length)
}
fn round_to_multiple_of(value: usize, base: usize) -> usize {
(value + (base - 1)) / base * base
}
}